diff --git a/.gitignore b/.gitignore
index 46b8e3a47c618..8560650344b95 100644
--- a/.gitignore
+++ b/.gitignore
@@ -97,5 +97,6 @@ python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py
 paddle/phi/kernels/fusion/cutlass/conv2d/generated/*
 python/paddle/base/incubate/fleet/parameter_server/pslib/ps_pb2.py
 paddle/fluid/ir_adaptor/translator/op_compat_info.cc
+paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen/*
 paddle/fluid/pybind/static_op_function.*
 paddle/fluid/pybind/ops_api.cc
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index 54131b48eca46..f5e6a6e426c60 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -34,7 +34,7 @@ This Code of Conduct applies both within project spaces and in public spaces whe
 
 ## Enforcement
 
-Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at paddle-dev@baidu.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
+Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at ext_paddle_oss@baidu.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
 
 Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.
diff --git a/CODE_OF_CONDUCT_cn.md b/CODE_OF_CONDUCT_cn.md
index 2be794f1f324c..92153a4dadcbe 100644
--- a/CODE_OF_CONDUCT_cn.md
+++ b/CODE_OF_CONDUCT_cn.md
@@ -36,7 +36,7 @@
 
 ## 强制执行
 
-可以通过paddle-dev@baidu.com,来联系项目团队来举报滥用、骚扰或其他不被接受的行为。
+可以通过ext_paddle_oss@baidu.com,来联系项目团队来举报滥用、骚扰或其他不被接受的行为。
 
 任何维护团队认为有必要且适合的所有投诉都将进行审查及调查,并做出相对应的回应。项目小组有对事件回报者有保密的义务。具体执行的方针近一步细节可能会单独公布。
diff --git a/cmake/cinn.cmake b/cmake/cinn.cmake
index a8ebe6a9a46ae..44d502fc4b792 100644
--- a/cmake/cinn.cmake
+++ b/cmake/cinn.cmake
@@ -164,8 +164,8 @@ cinn_cc_library(
 add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB)
 add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ${core_deps})
 if(NOT CINN_ONLY)
-  target_link_libraries(cinnapi pd_dialect phi)
-  add_dependencies(cinnapi pd_dialect phi)
+  target_link_libraries(cinnapi pd_op_dialect phi)
+  add_dependencies(cinnapi pd_op_dialect phi)
 endif()
 
 target_link_libraries(cinnapi ${PYTHON_LIBRARIES})
@@ -222,8 +222,8 @@ function(gen_cinncore LINKTYPE)
   add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB)
   add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ${core_deps})
   if(NOT CINN_ONLY)
-    target_link_libraries(${CINNCORE_TARGET} pd_dialect phi)
-    add_dependencies(${CINNCORE_TARGET} pd_dialect phi)
+    target_link_libraries(${CINNCORE_TARGET} pd_op_dialect phi)
+    add_dependencies(${CINNCORE_TARGET} pd_op_dialect phi)
   endif()
 
   add_dependencies(${CINNCORE_TARGET} pybind)
diff --git a/cmake/external/brpc.cmake b/cmake/external/brpc.cmake
index 3c9f2b6962048..d647e9116b586 100755
--- a/cmake/external/brpc.cmake
+++ b/cmake/external/brpc.cmake
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 include(ExternalProject)
-
+set(OPENSSL_USE_STATIC_LIBS ON)
 find_package(OpenSSL REQUIRED)
 
 message(STATUS "ssl:" ${OPENSSL_SSL_LIBRARY})
diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index 13fce9613650f..f73b20d389ef4 100755
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -269,10 +269,10 @@ else()
       DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
   endif()
   if(WITH_SHARED_IR)
-    set(paddle_ir_lib ${PADDLE_BINARY_DIR}/paddle/ir/libir.*)
+    set(paddle_pir_lib ${PADDLE_BINARY_DIR}/paddle/pir/libpir.*)
     copy(
       inference_lib_dist
-      SRCS ${paddle_ir_lib}
+      SRCS ${paddle_pir_lib}
       DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/lib)
   endif()
 endif()
diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt
index 92e302eb15acc..b5f2ffa394a89 100644
--- a/paddle/CMakeLists.txt
+++ b/paddle/CMakeLists.txt
@@ -3,7 +3,7 @@ set(PYTHON_TESTS_DIR
     CACHE INTERNAL "python tests directory")
 
 add_subdirectory(utils)
-add_subdirectory(ir)
+add_subdirectory(pir)
 add_subdirectory(scripts)
 add_subdirectory(testing)
 add_subdirectory(phi)
diff --git a/paddle/cinn/CMakeLists.txt b/paddle/cinn/CMakeLists.txt
index 4645ff2c06636..0f0f7beed265a 100644
--- a/paddle/cinn/CMakeLists.txt
+++ b/paddle/cinn/CMakeLists.txt
@@ -3,6 +3,7 @@ if(WITH_TESTING)
 endif()
 
 add_subdirectory(api)
+add_subdirectory(ast_gen_ius)
 add_subdirectory(auto_schedule)
 add_subdirectory(common)
 add_subdirectory(utils)
diff --git a/paddle/cinn/ast_gen_ius/CMakeLists.txt b/paddle/cinn/ast_gen_ius/CMakeLists.txt
new file mode 100644
index 0000000000000..c3908dfed2537
--- /dev/null
+++ b/paddle/cinn/ast_gen_ius/CMakeLists.txt
@@ -0,0 +1,6 @@
+core_gather_headers()
+
+gather_srcs(cinnapi_src SRCS ast_gen.cc tensor_group.cc)
+
+cinn_cc_test(test_ast_gen_ius SRCS ast_gen_test.cc DEPS cinncore)
+cinn_cc_test(test_tensor_group SRCS tensor_group_test.cc DEPS cinncore)
diff --git a/paddle/cinn/ast_gen_ius/ast_gen.cc b/paddle/cinn/ast_gen_ius/ast_gen.cc
new file mode 100644
index 0000000000000..d10560209e6ae
--- /dev/null
+++ b/paddle/cinn/ast_gen_ius/ast_gen.cc
@@ -0,0 +1,52 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/cinn/ast_gen_ius/ast_gen.h"
+#include "paddle/cinn/ir/ir.h"
+#include "paddle/cinn/ir/ir_base.h"
+#include "paddle/cinn/ir/operation.h"
+#include "paddle/cinn/ir/tensor.h"
+#include "paddle/cinn/ir/utils/ir_printer.h"
+
+namespace cinn {
+namespace ast_gen_ius {
+
+ir::Expr AstGen::Build(const ir::Tensor& tensor) {
+  const std::vector<ir::Var>& axis = tensor->axis();
+  const std::vector<ir::Expr>& shape = tensor->shape;
+  size_t axis_len = axis.size();
+  CHECK_EQ(shape.size(), axis_len)
+      << "Internal Error: Tensor has different shape and axis length in AstGen";
+
+  std::vector<ir::Expr> axis_exprs;
+  for (const auto& a : axis) {
+    axis_exprs.push_back(a);
+  }
+  ir::Expr body = ir::Store::Make(tensor, tensor->body(), axis_exprs);
+
+  for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
+    ir::Var loop_var = axis[i];
+    ir::Expr loop_extent = shape[i];
+    body = ir::For::Make(loop_var,
+                         Expr(0),
+                         loop_extent,
+                         ir::ForType::Serial,
+                         ir::DeviceAPI::Host,
+                         ir::Block::Make({body}));
+  }
+  return body;
+}
+
+}  // namespace ast_gen_ius
+}  // namespace cinn
diff --git a/paddle/cinn/optim/remove_nested_block.h b/paddle/cinn/ast_gen_ius/ast_gen.h
similarity index 66%
rename from paddle/cinn/optim/remove_nested_block.h
rename to paddle/cinn/ast_gen_ius/ast_gen.h
index 41220c18b254a..2e9dc7fde8d8e 100644
--- a/paddle/cinn/optim/remove_nested_block.h
+++ b/paddle/cinn/ast_gen_ius/ast_gen.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,22 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-/**
- * This file implements the strategy to remove the unnecessary nested block.
- */
 #pragma once
-#include <string>
 
-#include "paddle/cinn/common/common.h"
 #include "paddle/cinn/ir/ir.h"
+#include "paddle/cinn/ir/ir_base.h"
+#include "paddle/cinn/ir/tensor.h"
 
 namespace cinn {
-namespace optim {
+namespace ast_gen_ius {
 
-/**
- * Remove the unecessary nested block.
- */
-void RemoveNestedBlock(Expr* e);
+class AstGen {
+ public:
+  static ir::Expr Build(const ir::Tensor& tensor);
+};
 
-}  // namespace optim
+}  // namespace ast_gen_ius
 }  // namespace cinn
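AstGen::Build emits the Store for the tensor body first, then walks the axes in reverse so that axis 0 ends up as the outermost loop. A minimal standalone sketch of that wrapping order follows; the names (BuildLoopNest, the string-based "IR") are illustrative stand-ins, not the CINN API.

// Standalone sketch: wrap a statement in loops from the innermost axis
// outwards, mirroring the reverse iteration in AstGen::Build.
#include <iostream>
#include <string>
#include <vector>

std::string BuildLoopNest(const std::vector<std::string>& axis,
                          const std::vector<int>& extents,
                          std::string body) {
  // Walk the axes back to front so axis[0] becomes the outermost loop.
  for (int i = static_cast<int>(axis.size()) - 1; i >= 0; --i) {
    body = "for (int " + axis[i] + " = 0; " + axis[i] + " < " +
           std::to_string(extents[i]) + "; ++" + axis[i] + ") { " + body +
           " }";
  }
  return body;
}

int main() {
  // A 2-D store B[i, j] = relu(A[i, j]) becomes a j-inner, i-outer nest.
  std::cout << BuildLoopNest({"i", "j"}, {10, 10}, "B[i, j] = relu(A[i, j]);")
            << std::endl;
}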
diff --git a/paddle/cinn/ast_gen_ius/ast_gen_test.cc b/paddle/cinn/ast_gen_ius/ast_gen_test.cc
new file mode 100644
index 0000000000000..e91c0f4ca0e28
--- /dev/null
+++ b/paddle/cinn/ast_gen_ius/ast_gen_test.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "paddle/cinn/ast_gen_ius/ast_gen.h"
+#include "paddle/cinn/ir/ir.h"
+#include "paddle/cinn/ir/ir_base.h"
+#include "paddle/cinn/ir/tensor.h"
+#include "paddle/cinn/lang/builtin.h"
+#include "paddle/cinn/lang/compute.h"
+#include "paddle/cinn/lang/placeholder.h"
+
+namespace cinn {
+namespace ast_gen_ius {
+
+using cinn::ir::Expr;
+using cinn::ir::Tensor;
+
+TEST(AstGen, Build) {
+  std::vector<Expr> shape = {Expr(10), Expr(10), Expr(10), Expr(10)};
+  lang::Placeholder<float> A("A", shape);
+  Tensor B = lang::Compute(
+      shape,
+      [&](const std::vector<Expr>& indice) { return lang::Relu(A(indice), 0); },
+      "relu_test");
+  Expr out = AstGen::Build(B);
+  LOG(INFO) << out;
+}
+
+}  // namespace ast_gen_ius
+}  // namespace cinn
diff --git a/paddle/cinn/ast_gen_ius/tensor_group.cc b/paddle/cinn/ast_gen_ius/tensor_group.cc
new file mode 100644
index 0000000000000..cca8b4136ba1b
--- /dev/null
+++ b/paddle/cinn/ast_gen_ius/tensor_group.cc
@@ -0,0 +1,198 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/cinn/ast_gen_ius/tensor_group.h"
+
+#include <set>
+#include <unordered_set>
+
+#include "paddle/cinn/ir/ir.h"
+#include "paddle/cinn/ir/ir_base.h"
+#include "paddle/cinn/ir/tensor.h"
+#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
+
+namespace cinn {
+namespace ast_gen_ius {
+
+TensorGroup::TensorGroup(const std::vector<ir::Tensor>& tensors) {
+  std::set<ir::Tensor> all_tensors(tensors.begin(), tensors.end());
+
+  for (auto& tensor : tensors) {
+    output_tensor_names_.insert(tensor->name);
+    std::set<ir::Expr> used_tensors = ir::CollectIRNodes(
+        tensor->body(), [](const Expr* x) { return x->as_tensor(); });
+    for (const Expr& x : used_tensors) {
+      const ir::Tensor to_dep = x.as_tensor_ref();
+      all_tensors.insert(to_dep);
+      this->CtrlDepend(tensor, to_dep);
+    }
+  }
+
+  for (const ir::Tensor& t : all_tensors) {
+    name_to_tensor_.insert({t->name, t});
+  }
+}
+
+TensorGroup::~TensorGroup() {}
+
+bool TensorGroup::Contain(const std::string& name) const {
+  return name_to_tensor_.find(name) != name_to_tensor_.end();
+}
+
+void TensorGroup::Insert(const ir::Tensor& tensor) {
+  name_to_tensor_.insert({tensor->name, tensor});
+}
+
+ir::Tensor TensorGroup::Get(const std::string& name) {
+  return name_to_tensor_[name];
+}
+
+std::set<ir::Tensor> TensorGroup::GetAllTensors() {
+  std::set<ir::Tensor> all_tensors;
+  for (const std::pair<std::string, ir::Tensor>& p : name_to_tensor_) {
+    all_tensors.insert(p.second);
+  }
+  return all_tensors;
+}
+
+std::vector<ir::Tensor> TensorGroup::GetGenFuncTopoOrder(
+    const std::vector<ir::Tensor>& func_args) {
+  std::unordered_map<std::string, int> in_degree;
+  for (const auto& dep_pair : ctrl_dep_) {
+    const std::unordered_set<std::string>& dep_tensor_names = dep_pair.second;
+    in_degree[dep_pair.first] = dep_tensor_names.size();
+  }
+
+  std::vector<ir::Tensor> ret;
+  std::vector<std::string> stack;
+  for (const auto& name_tensor : name_to_tensor_) {
+    if (!in_degree.count(name_tensor.first)) {
+      stack.emplace_back(name_tensor.first);
+    }
+  }
+
+  std::set<std::string> input_arg_names;
+  for (const ir::Tensor& arg : func_args) {
+    input_arg_names.insert(arg->name);
+  }
+  for (const std::string& name : output_tensor_names_) {
+    input_arg_names.erase(name);
+  }
+
+  while (!stack.empty()) {
+    std::string cur = stack.back();
+    stack.pop_back();
+
+    if (!input_arg_names.count(cur)) {
+      ret.push_back(name_to_tensor_[cur]);
+    }
+
+    for (const auto& dep_pair : ctrl_dep_) {
+      const std::unordered_set<std::string>& dep_tensor_names = dep_pair.second;
+      if (dep_tensor_names.count(cur)) {
+        --in_degree[dep_pair.first];
+        if (in_degree[dep_pair.first] == 0) {
+          stack.emplace_back(dep_pair.first);
+        }
+      }
+    }
+  }
+  return ret;
+}
+
+bool TensorGroup::HasMarkedReduceInit(const std::string& tensor_name) const {
+  return tensor_name_needs_reduce_init_.count(tensor_name);
+}
+
+ir::Tensor TensorGroup::MarkReduceInit(const std::string& tensor_name) {
+  // TODO(zhhsplendid): add check
+  tensor_name_needs_reduce_init_.insert(tensor_name);
+}
+
+void TensorGroup::CtrlDepend(const ir::Tensor& tensor,
+                             const ir::Tensor& to_dep) {
+  ctrl_dep_[tensor->name].insert(to_dep->name);
+  if (!name_to_tensor_.count(tensor->name)) {
+    name_to_tensor_[tensor->name] = tensor;
+  }
+  if (!name_to_tensor_.count(to_dep->name)) {
+    name_to_tensor_[to_dep->name] = to_dep;
+  }
+}
+
+std::set<ir::Tensor> TensorGroup::GetCrtlDepTensors(
+    const std::string& tensor_name) {
+  if (!ctrl_dep_.count(tensor_name)) {
+    return {};
+  }
+  std::set<ir::Tensor> ret;
+  for (const std::string& dep_name : ctrl_dep_[tensor_name]) {
+    ret.insert(name_to_tensor_[dep_name]);
+  }
+  return ret;
+}
+
+std::string TensorGroup::GetShareMemRootName(const std::string& tensor_name) {
+  if (!share_memory_tensor_.count(tensor_name)) {
+    share_memory_tensor_[tensor_name] = tensor_name;
+    return tensor_name;
+  }
+  if (share_memory_tensor_[tensor_name] == tensor_name) {
+    return tensor_name;
+  }
+  share_memory_tensor_[tensor_name] =
+      GetShareMemRootName(share_memory_tensor_[tensor_name]);
+  return share_memory_tensor_[tensor_name];
+}
+
+void TensorGroup::ShareMemoryBuffer(const ir::Tensor& tensor,
+                                    const ir::Tensor& to_share) {
+  share_memory_tensor_[GetShareMemRootName(to_share->name)] =
+      GetShareMemRootName(tensor->name);
+}
+
+absl::flat_hash_map<std::string, ir::Tensor> TensorGroup::AllocateBuffers() {
+  std::unordered_set<std::string> allocated_roots;
+  for (auto& name_tensor : name_to_tensor_) {
+    std::string root_name = GetShareMemRootName(name_tensor.first);
+
+    // Allocate root buffer
+    if (!allocated_roots.count(root_name)) {
+      ir::Tensor root_tensor = name_to_tensor_[root_name];
+      if (!root_tensor->buffer.defined() && !root_tensor->type().is_void()) {
+        root_tensor->WithBuffer();
+        VLOG(6) << "Bind root_tensor " << root_name << " with buffer "
+                << root_tensor->buffer->name;
+      }
+      allocated_roots.insert(root_name);
+    }
+
+    // Share buffer
+    if (root_name != name_tensor.first) {
+      ir::Tensor& root_tensor = name_to_tensor_[root_name];
+      ir::Tensor& tensor = name_tensor.second;
+
+      auto keep_shape = root_tensor->buffer->shape;
+      tensor->Bind(root_tensor->buffer);
+      root_tensor->buffer->shape = keep_shape;
+      tensor->buffer->shape = keep_shape;
+      VLOG(6) << "Share buffer " << root_name << " with " << name_tensor.first;
+    }
+  }
+
+  return name_to_tensor_;
+}
+
+}  // namespace ast_gen_ius
+}  // namespace cinn
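GetShareMemRootName and ShareMemoryBuffer above implement a union-find over tensor names with path compression: every name points at a parent, and two buffer-sharing sets are merged by re-rooting one root under the other. A minimal standalone sketch of the same pattern (NameUnionFind and its members are illustrative names, not CINN code):

// Standalone union-find sketch: Find() compresses paths the way
// GetShareMemRootName does; Union() mirrors ShareMemoryBuffer.
#include <cassert>
#include <string>
#include <unordered_map>

class NameUnionFind {
 public:
  std::string Find(const std::string& name) {
    if (!parent_.count(name)) parent_[name] = name;  // first sight: own root
    if (parent_[name] == name) return name;
    parent_[name] = Find(parent_[name]);             // path compression
    return parent_[name];
  }
  void Union(const std::string& a, const std::string& b) {
    parent_[Find(b)] = Find(a);                      // b's set joins a's set
  }

 private:
  std::unordered_map<std::string, std::string> parent_;
};

int main() {
  NameUnionFind uf;
  uf.Union("B", "C");  // C shares B's buffer
  uf.Union("A", "B");  // B (and transitively C) shares A's buffer
  assert(uf.Find("C") == uf.Find("A"));
}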
diff --git a/paddle/cinn/ast_gen_ius/tensor_group.h b/paddle/cinn/ast_gen_ius/tensor_group.h
new file mode 100644
index 0000000000000..1fa37c730c455
--- /dev/null
+++ b/paddle/cinn/ast_gen_ius/tensor_group.h
@@ -0,0 +1,82 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <absl/container/flat_hash_map.h>
+
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "paddle/cinn/ir/ir.h"
+#include "paddle/cinn/ir/ir_base.h"
+#include "paddle/cinn/ir/tensor.h"
+
+namespace cinn {
+namespace ast_gen_ius {
+
+/* Collection used for Tensors, used in AST generation */
+class TensorGroup {
+ public:
+  explicit TensorGroup(const std::vector<ir::Tensor>& tensors);
+  ~TensorGroup();
+
+  bool Contain(const std::string& name) const;
+
+  void Insert(const ir::Tensor& tensor);
+
+  ir::Tensor Get(const std::string& name);
+
+  std::set<ir::Tensor> GetAllTensors();
+
+  void CtrlDepend(const ir::Tensor& tensor, const ir::Tensor& to_dep);
+
+  std::set<ir::Tensor> GetCrtlDepTensors(const std::string& tensor_name);
+
+  std::string GetShareMemRootName(const std::string& tensor_name);
+
+  void ShareMemoryBuffer(const ir::Tensor& tensor, const ir::Tensor& to_share);
+
+  absl::flat_hash_map<std::string, ir::Tensor> AllocateBuffers();
+
+  // Returns tensors in topological order and removes the given args.
+  // Because the order is used for generating the function body, we don't
+  // have to generate the args.
+  std::vector<ir::Tensor> GetGenFuncTopoOrder(
+      const std::vector<ir::Tensor>& func_args = {});
+
+  bool HasMarkedReduceInit(const std::string& tensor_name) const;
+
+  // Marks a tensor that needs reduce init
+  ir::Tensor MarkReduceInit(const std::string& tensor_name);
+
+ private:
+  std::set<std::string> output_tensor_names_;
+
+  absl::flat_hash_map<std::string, ir::Tensor> name_to_tensor_;
+
+  // Stores the names of the tensors that the key tensor depends on
+  std::unordered_map<std::string, std::unordered_set<std::string>> ctrl_dep_;
+
+  // Union-find-set style: each tensor name whose buffer is shared maps to
+  // the same root tensor name
+  std::unordered_map<std::string, std::string> share_memory_tensor_;
+
+  std::unordered_set<std::string> tensor_name_needs_reduce_init_;
+};
+
+}  // namespace ast_gen_ius
+}  // namespace cinn
diff --git a/paddle/cinn/ast_gen_ius/tensor_group_test.cc b/paddle/cinn/ast_gen_ius/tensor_group_test.cc
new file mode 100644
index 0000000000000..3711419da9c56
--- /dev/null
+++ b/paddle/cinn/ast_gen_ius/tensor_group_test.cc
@@ -0,0 +1,61 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "paddle/cinn/ast_gen_ius/tensor_group.h"
+#include "paddle/cinn/ir/ir.h"
+#include "paddle/cinn/ir/ir_base.h"
+#include "paddle/cinn/ir/tensor.h"
+#include "paddle/cinn/lang/compute.h"
+#include "paddle/cinn/lang/placeholder.h"
+
+namespace cinn {
+namespace ast_gen_ius {
+
+using ir::Expr;
+using ir::Tensor;
+using ir::Var;
+using lang::Compute;
+using lang::Placeholder;
+
+TEST(TensorGroup, Easy) {
+  auto M = Expr(100);
+  auto N = Expr(15);
+  Placeholder<float> A("A", {M, N});
+
+  Tensor B = Compute(
+      {M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + 1.f; }, "B");
+
+  TensorGroup tensor_group({B});
+
+  ASSERT_TRUE(tensor_group.Contain("A"));
+  ASSERT_TRUE(tensor_group.Contain("B"));
+  ASSERT_EQ(tensor_group.Get("B")->name, "B");
+  ASSERT_EQ(tensor_group.Get("A")->name, "A");
+  ASSERT_EQ(tensor_group.GetAllTensors().size(), 2UL);
+
+  ASSERT_EQ(tensor_group.GetCrtlDepTensors("A").size(), 0UL);
+  ASSERT_EQ(tensor_group.GetCrtlDepTensors("B").size(), 1UL);
+  ASSERT_TRUE(tensor_group.GetCrtlDepTensors("B").count(A));
+
+  std::vector<ir::Tensor> topo_tensors =
+      tensor_group.GetGenFuncTopoOrder({A.tensor(), B});
+  ASSERT_EQ(topo_tensors.size(), 1UL);
+  ASSERT_EQ(topo_tensors[0]->name, "B");
+}
+
+}  // namespace ast_gen_ius
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/cost_model/CMakeLists.txt b/paddle/cinn/auto_schedule/cost_model/CMakeLists.txt
index 7f514471a4f7a..17af89c8ae2a1 100644
--- a/paddle/cinn/auto_schedule/cost_model/CMakeLists.txt
+++ b/paddle/cinn/auto_schedule/cost_model/CMakeLists.txt
@@ -3,7 +3,8 @@ core_gather_headers()
 gather_srcs(cinnapi_src SRCS xgb_cost_model.cc expr_cost_model.cc feature.cc
             feature_extractor.cc)
 
-cinn_cc_test(test_xgb_cost_model SRCS xgb_cost_model_test.cc DEPS cinncore)
+# TODO(zhhsplendid): enable this test again
+#cinn_cc_test(test_xgb_cost_model SRCS xgb_cost_model_test.cc DEPS cinncore)
 
 cinn_cc_test(test_feature_extractor SRCS feature_extractor_test.cc DEPS
              cinncore)
 cinn_cc_test(test_feature SRCS feature_test.cc DEPS cinncore)
diff --git a/paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt b/paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt
index 7f393dfb39837..ab1db5f7bb1bd 100644
--- a/paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt
+++ b/paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt
@@ -4,5 +4,6 @@ core_gather_headers()
 
 gather_srcs(cinnapi_src SRCS evolutionary_search.cc)
 
-cinn_cc_test(test_evolutionary_search SRCS evolutionary_search_test.cc DEPS
-             cinncore test_program_builder)
+# TODO(zhhsplendid): enable this test again
+#cinn_cc_test(test_evolutionary_search SRCS evolutionary_search_test.cc DEPS
+#             cinncore test_program_builder)
diff --git a/paddle/cinn/backends/codegen_c.cc b/paddle/cinn/backends/codegen_c.cc
index cffebdc1a6736..3352a458ceceb 100644
--- a/paddle/cinn/backends/codegen_c.cc
+++ b/paddle/cinn/backends/codegen_c.cc
@@ -23,7 +23,6 @@
 #include "paddle/cinn/ir/op/ir_operators.h"
 #include "paddle/cinn/ir/utils/ir_verify.h"
 #include "paddle/cinn/optim/ir_simplify.h"
-#include "paddle/cinn/optim/remove_nested_block.h"
 #include "paddle/cinn/runtime/cpu/thread_backend.h"
 #include "paddle/cinn/runtime/intrinsic.h"
 #include "paddle/cinn/utils/string.h"
@@ -645,7 +644,7 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
 
   Expr func_body = ir::Block::Make(new_body);
 
-  optim::RemoveNestedBlock(&func_body);
+  optim::SimplifyBlocks(&func_body);
 
   IrPrinter::Visit(func_body);
 }
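The codegen backends now call optim::SimplifyBlocks where they previously called the removed optim::RemoveNestedBlock; both flatten redundantly nested Block nodes before printing. A minimal standalone sketch of that flattening idea, using a toy Stmt tree rather than CINN's ir::Block (all names here are illustrative):

// Standalone sketch of nested-block flattening: splice the statements of
// inner blocks into one flat sequence.
#include <iostream>
#include <string>
#include <vector>

struct Stmt {
  std::string code;          // leaf statement text
  std::vector<Stmt> block;   // non-empty => this node is a nested block
};

void Flatten(const Stmt& s, std::vector<std::string>* out) {
  if (s.block.empty()) {
    out->push_back(s.code);
    return;
  }
  for (const Stmt& child : s.block) Flatten(child, out);  // recurse into blocks
}

int main() {
  // { a; { b; c; } } flattens to { a; b; c; }
  Stmt root{"", {Stmt{"a;", {}}, Stmt{"", {Stmt{"b;", {}}, Stmt{"c;", {}}}}}};
  std::vector<std::string> flat;
  Flatten(root, &flat);
  for (const auto& s : flat) std::cout << s << " ";
}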
diff --git a/paddle/cinn/backends/codegen_cuda_dev.cc b/paddle/cinn/backends/codegen_cuda_dev.cc
index 018f935482c7f..e33154f0c0129 100644
--- a/paddle/cinn/backends/codegen_cuda_dev.cc
+++ b/paddle/cinn/backends/codegen_cuda_dev.cc
@@ -24,7 +24,6 @@
 #include "paddle/cinn/ir/op/ir_operators.h"
 #include "paddle/cinn/ir/utils/ir_verify.h"
 #include "paddle/cinn/optim/ir_simplify.h"
-#include "paddle/cinn/optim/remove_nested_block.h"
 
 namespace cinn {
 namespace backends {
@@ -141,7 +140,7 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
 
   Expr func_body = ir::Block::Make(new_body);
 
-  optim::RemoveNestedBlock(&func_body);
+  optim::SimplifyBlocks(&func_body);
   // Make sure that the function's body is wrapped by a block
   if (!func_body.As<ir::Block>()) {
     func_body = ir::Block::Make({func_body});
   }
diff --git a/paddle/cinn/backends/llvm/codegen_x86.cc b/paddle/cinn/backends/llvm/codegen_x86.cc
index 28159f9ea4e4f..ccae02ac5746b 100644
--- a/paddle/cinn/backends/llvm/codegen_x86.cc
+++ b/paddle/cinn/backends/llvm/codegen_x86.cc
@@ -28,7 +28,7 @@
 #include "paddle/cinn/common/target.h"
 #include "paddle/cinn/ir/ir.h"
 #include "paddle/cinn/ir/op/ir_operators.h"
-#include "paddle/cinn/optim/collect_undefined_vars.h"
+#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
 #include "paddle/cinn/runtime/intrinsic.h"
 
 namespace cinn::backends {
@@ -98,7 +98,7 @@ void CodeGenX86::CreateParallelLaunch(Expr body, int num_task) {
                              llvm::Function::PrivateLinkage,
                              "__parallel_lambda",
                              m_);
-  std::vector<std::string> vars = optim::CollectUndefinedVars(&body);
+  std::vector<std::string> vars = ir::CollectUndefinedVars(&body);
   uint64_t nbytes;
   auto* data = PackVars(vars, &nbytes);
diff --git a/paddle/cinn/hlir/dialect/CMakeLists.txt b/paddle/cinn/hlir/dialect/CMakeLists.txt
index 5d30ab6d34504..3787fdf2b4b08 100755
--- a/paddle/cinn/hlir/dialect/CMakeLists.txt
+++ b/paddle/cinn/hlir/dialect/CMakeLists.txt
@@ -1,2 +1,2 @@
-add_subdirectory(cinn_dialect)
-add_subdirectory(runtime_dialect)
+add_subdirectory(operator)
+add_subdirectory(runtime)
diff --git a/paddle/cinn/hlir/dialect/generated/cinn_ops.parsed.yaml b/paddle/cinn/hlir/dialect/generated/cinn_ops.parsed.yaml
deleted file mode 100644
index b345bb699084e..0000000000000
--- a/paddle/cinn/hlir/dialect/generated/cinn_ops.parsed.yaml
+++ /dev/null
@@ -1,31 +0,0 @@
-- name: add
-  inputs:
-  - typename: Tensor
-    name: x
-    optional: false
-    no_need_buffer: false
-    data_transform: {}
-  - typename: Tensor
-    name: y
-    optional: false
-    no_need_buffer: false
-    data_transform: {}
-  attrs: []
-  outputs:
-  - {typename: Tensor, name: out, optional: false, intermediate: false}
-  no_need_buffer: null
-  data_transform: null
-  infer_meta:
-    func: ElementwiseInferMeta
-    param: [x, y]
-  kernel:
-    func: [add]
-    param: [x, y]
-    backend: null
-    layout: null
-    data_type: null
-    dispatch: {add: null}
-    force_backend: null
-  inplace: {out: x}
-  view: null
-  backward: null
diff --git a/paddle/cinn/hlir/dialect/cinn_dialect/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/CMakeLists.txt
similarity index 100%
rename from paddle/cinn/hlir/dialect/cinn_dialect/CMakeLists.txt
rename to paddle/cinn/hlir/dialect/operator/CMakeLists.txt
diff --git a/paddle/cinn/hlir/dialect/cinn_dialect/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt
similarity index 71%
rename from paddle/cinn/hlir/dialect/cinn_dialect/ir/CMakeLists.txt
rename to paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt
index 5fa53f74cc4a9..896a727f7e59f 100644
--- a/paddle/cinn/hlir/dialect/cinn_dialect/ir/CMakeLists.txt
+++ b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt
@@ -1,31 +1,30 @@
-# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and could
+# TODO(Aurelius84): new_ir_compiler depends on pd_op_dialect and could
 # not found under CINN_ONLY mode
 if(NOT CINN_ONLY)
   set(CINN_DIALECT_BINARY_DIR
-      "${PADDLE_BINARY_DIR}/paddle/cinn/hlir/dialect/cinn_dialect/ir")
+      "${PADDLE_BINARY_DIR}/paddle/cinn/hlir/dialect/operator/ir")
 
-  # Generate cinn_dialect files defining op using op_gen_file
+  # Generate cinn_op_dialect files defining op using op_gen_file
   set(cinn_op_gen_parsed_yaml_file
       ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parse_op.py)
 
   set(cinn_op_gen_file
-      ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/op_gen.py)
+      ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_gen.py)
 
   set(cinn_op_compat_yaml_file
      ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
 
   set(cinn_op_yaml_file
-      ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_ops.yaml
-  )
+      ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/operator/ir/ops.yaml)
 
   set(parsed_op_dir ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/generated)
 
-  set(cinn_op_parsed_yaml_file ${parsed_op_dir}/cinn_ops.parsed.yaml)
+  set(cinn_op_parsed_yaml_file ${parsed_op_dir}/ops.parsed.yaml)
 
   set(cinn_op_parsed_yaml_files ${cinn_op_parsed_yaml_file})
 
   set(cinn_op_namespace cinn,dialect)
-  set(cinn_dialect_name cinn)
+  set(cinn_op_dialect_name cinn_op)
 
   set(cinn_op_header_file ${CINN_DIALECT_BINARY_DIR}/cinn_op.h)
   set(cinn_op_source_file ${CINN_DIALECT_BINARY_DIR}/cinn_op.cc)
   set(cinn_op_header_file_tmp ${cinn_op_header_file}.tmp)
@@ -44,7 +43,7 @@ if(NOT CINN_ONLY)
     ${PYTHON_EXECUTABLE} ${cinn_op_gen_file} --op_yaml_files
     ${cinn_op_parsed_yaml_files} --op_compat_yaml_file
    ${cinn_op_compat_yaml_file} --namespaces ${cinn_op_namespace}
-    --dialect_name ${cinn_dialect_name} --op_def_h_file
+    --dialect_name ${cinn_op_dialect_name} --op_def_h_file
     ${cinn_op_header_file_tmp} --op_def_cc_file ${cinn_op_source_file_tmp}
     COMMAND ${CMAKE_COMMAND} -E copy_if_different ${cinn_op_header_file_tmp}
             ${cinn_op_header_file}
@@ -54,8 +53,8 @@ if(NOT CINN_ONLY)
             ${cinn_op_compat_yaml_file}
     VERBATIM)
 
-  cinn_cc_library(cinn_dialect SRCS cinn_dialect.cc ${cinn_op_source_file} DEPS
-                  pd_dialect)
+  cinn_cc_library(cinn_op_dialect SRCS op_dialect.cc ${cinn_op_source_file}
+                  DEPS pd_op_dialect)
 
-  target_include_directories(cinn_dialect PRIVATE ${CINN_DIALECT_BINARY_DIR})
+  target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_BINARY_DIR})
 endif()
diff --git a/paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_dialect.cc b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
similarity index 68%
rename from paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_dialect.cc
rename to paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
index 9e8ccfb6492e4..d8a3bc7b8b35a 100644
--- a/paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_dialect.cc
+++ b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
@@ -12,31 +12,32 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_dialect.h"
+#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
 
 // NOTE(chenxi67): File cinn_op.h is generated by op_gen.py, see details in
 // paddle/cinn/hlir/dialect/CMakeLists.txt.
-#include "paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" namespace cinn { namespace dialect { -CinnDialect::CinnDialect(::ir::IrContext* context) - : ::ir::Dialect( - name(), context, ::ir::TypeId::get()) { +OperatorDialect::OperatorDialect(::pir::IrContext* context) + : ::pir::Dialect(name(), + context, + ::pir::TypeId::get()) { this->initialize(); } -void CinnDialect::initialize() { +void OperatorDialect::initialize() { // NOTE(chenxi67): GET_OP_LIST is defined in cinn_op.h which is // generated by op_gen.py, see details in // paddle/cinn/hlir/dialect/CMakeLists.txt. RegisterOps< #define GET_OP_LIST -#include "paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_op.h" // NOLINT +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" // NOLINT >(); } } // namespace dialect } // namespace cinn -IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::CinnDialect) +IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::OperatorDialect) diff --git a/paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_dialect.h b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.h similarity index 75% rename from paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_dialect.h rename to paddle/cinn/hlir/dialect/operator/ir/op_dialect.h index 77fb96863ad37..58a0487e9e8f9 100644 --- a/paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_dialect.h +++ b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.h @@ -14,16 +14,16 @@ #pragma once -#include "paddle/ir/core/dialect.h" +#include "paddle/pir/core/dialect.h" namespace cinn { namespace dialect { -class CinnDialect : public ::ir::Dialect { +class OperatorDialect : public ::pir::Dialect { public: - explicit CinnDialect(::ir::IrContext* context); + explicit OperatorDialect(::pir::IrContext* context); - static const char* name() { return "cinn"; } + static const char* name() { return "cinn_op"; } private: void initialize(); @@ -32,4 +32,4 @@ class CinnDialect : public ::ir::Dialect { } // namespace dialect } // namespace cinn -IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::CinnDialect) +IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::OperatorDialect) diff --git a/paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_ops.yaml b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml similarity index 100% rename from paddle/cinn/hlir/dialect/cinn_dialect/ir/cinn_ops.yaml rename to paddle/cinn/hlir/dialect/operator/ir/ops.yaml diff --git a/paddle/cinn/hlir/dialect/runtime_dialect/CMakeLists.txt b/paddle/cinn/hlir/dialect/runtime/CMakeLists.txt similarity index 100% rename from paddle/cinn/hlir/dialect/runtime_dialect/CMakeLists.txt rename to paddle/cinn/hlir/dialect/runtime/CMakeLists.txt diff --git a/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt new file mode 100644 index 0000000000000..6023117faee09 --- /dev/null +++ b/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt @@ -0,0 +1,4 @@ +if(NOT CINN_ONLY) + cinn_cc_library(cinn_runtime_dialect SRCS runtime_dialect.cc jit_kernel_op.cc + DEPS pir_core) +endif() diff --git a/paddle/cinn/hlir/dialect/runtime_dialect/ir/jit_kernel_op.cc b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc similarity index 80% rename from paddle/cinn/hlir/dialect/runtime_dialect/ir/jit_kernel_op.cc rename to paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc index 49e3685a8475a..ed3d4a4045c59 100644 --- a/paddle/cinn/hlir/dialect/runtime_dialect/ir/jit_kernel_op.cc +++ b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing 
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/cinn/hlir/dialect/runtime_dialect/ir/jit_kernel_op.h"
+#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h"
 
-#include "paddle/ir/core/builtin_attribute.h"
-#include "paddle/ir/core/enforce.h"
+#include "paddle/pir/core/builtin_attribute.h"
+#include "paddle/pir/core/enforce.h"
 
 namespace cinn {
 namespace dialect {
@@ -28,13 +28,13 @@ void JitKernelOp::Verify() {
 
   auto& attributes = this->attributes();
 
   IR_ENFORCE(attributes.count(kAttrName) > 0 &&
-                 attributes.at(kAttrName).isa<::ir::PointerAttribute>(),
+                 attributes.at(kAttrName).isa<::pir::PointerAttribute>(),
              "Type of attribute: instruction is not right.");
 }
 
 hlir::framework::Instruction* JitKernelOp::instruction() {
   void* ptr =
-      attributes().at(kAttrName).dyn_cast<::ir::PointerAttribute>().data();
+      attributes().at(kAttrName).dyn_cast<::pir::PointerAttribute>().data();
   return reinterpret_cast<hlir::framework::Instruction*>(ptr);
 }
diff --git a/paddle/cinn/hlir/dialect/runtime_dialect/ir/jit_kernel_op.h b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h
similarity index 91%
rename from paddle/cinn/hlir/dialect/runtime_dialect/ir/jit_kernel_op.h
rename to paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h
index 37b9c66bb6e17..f410e4d46c021 100644
--- a/paddle/cinn/hlir/dialect/runtime_dialect/ir/jit_kernel_op.h
+++ b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h
@@ -14,7 +14,7 @@
 
 #pragma once
 
-#include "paddle/ir/core/op_base.h"
+#include "paddle/pir/core/op_base.h"
 
 namespace cinn {
 
@@ -40,10 +40,10 @@ namespace dialect {
  * temporarily, and will split executor information like
 * scope, inputs, outputs into the InterpreterCore module.
 */
-class JitKernelOp : public ::ir::Op<JitKernelOp> {
+class JitKernelOp : public ::pir::Op<JitKernelOp> {
  public:
   using Op::Op;
-  static const char* name() { return "cinn.jit_kernel"; }
+  static const char* name() { return "cinn_runtime.jit_kernel"; }
   // TODO(Aurelius84): Think deeply what should contains
   static constexpr uint32_t attributes_num = 1;
   static constexpr char* kAttrName = "instruction";
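JitKernelOp stores a raw Instruction* inside a PointerAttribute at creation time and recovers it in instruction() with a reinterpret_cast — a type-erased pointer round-trip. A minimal standalone sketch of that round-trip follows; the AttributeMap stand-in is illustrative and deliberately much simpler than the real pir attribute classes.

// Standalone sketch (not the pir API): stash a typed pointer as a
// type-erased void*, then recover it with a cast, as JitKernelOp does.
#include <cassert>
#include <string>
#include <unordered_map>

struct Instruction { int id = 42; };

// Stand-in for an op's attribute map holding PointerAttribute-like payloads.
using AttributeMap = std::unordered_map<std::string, void*>;

int main() {
  Instruction instr;
  AttributeMap attrs;
  attrs["instruction"] = &instr;  // like PointerAttribute::get(ctx, &instr)

  // Like JitKernelOp::instruction(): fetch the erased pointer and cast it
  // back to its real type.
  auto* recovered = reinterpret_cast<Instruction*>(attrs.at("instruction"));
  assert(recovered->id == 42);
}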
-#include "paddle/cinn/hlir/dialect/runtime_dialect/ir/runtime_dialect.h" -#include "paddle/cinn/hlir/dialect/runtime_dialect/ir/jit_kernel_op.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" namespace cinn { namespace dialect { -RuntimeDialect::RuntimeDialect(::ir::IrContext* context) - : ::ir::Dialect( - name(), context, ::ir::TypeId::get()) { +RuntimeDialect::RuntimeDialect(::pir::IrContext* context) + : ::pir::Dialect(name(), + context, + ::pir::TypeId::get()) { this->initialize(); } diff --git a/paddle/cinn/hlir/dialect/runtime_dialect/ir/runtime_dialect.h b/paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h similarity index 81% rename from paddle/cinn/hlir/dialect/runtime_dialect/ir/runtime_dialect.h rename to paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h index a35c7a24b8d7f..8ba0af9334498 100644 --- a/paddle/cinn/hlir/dialect/runtime_dialect/ir/runtime_dialect.h +++ b/paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h @@ -14,16 +14,16 @@ #pragma once -#include "paddle/ir/core/dialect.h" +#include "paddle/pir/core/dialect.h" namespace cinn { namespace dialect { -class RuntimeDialect : public ::ir::Dialect { +class RuntimeDialect : public ::pir::Dialect { public: - explicit RuntimeDialect(::ir::IrContext* context); + explicit RuntimeDialect(::pir::IrContext* context); - static const char* name() { return "cinn"; } + static const char* name() { return "cinn_runtime"; } private: void initialize(); diff --git a/paddle/cinn/hlir/dialect/runtime_dialect/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/runtime_dialect/ir/CMakeLists.txt deleted file mode 100644 index 1df80a5bb3f75..0000000000000 --- a/paddle/cinn/hlir/dialect/runtime_dialect/ir/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -if(NOT CINN_ONLY) - cinn_cc_library(runtime_dialect SRCS runtime_dialect.cc jit_kernel_op.cc DEPS - ir_core) -endif() diff --git a/paddle/cinn/hlir/framework/CMakeLists.txt b/paddle/cinn/hlir/framework/CMakeLists.txt index d14ffa70234fc..5e202578b125c 100755 --- a/paddle/cinn/hlir/framework/CMakeLists.txt +++ b/paddle/cinn/hlir/framework/CMakeLists.txt @@ -23,13 +23,13 @@ gather_srcs( accuracy_checker.cc visualize_helper.cc) -# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and could +# TODO(Aurelius84): new_ir_compiler depends on pd_op_dialect and could # not found under CINN_ONLY mode if(NOT CINN_ONLY) cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinnapi - pd_dialect) + pd_op_dialect) cinn_cc_library(convert_to_dialect SRCS convert_to_dialect.cc DEPS cinnapi - cinn_dialect) + cinn_op_dialect) endif() if(WITH_CUDA) diff --git a/paddle/cinn/hlir/framework/convert_to_dialect.cc b/paddle/cinn/hlir/framework/convert_to_dialect.cc index 306e27dc1fea5..f76b49a54555f 100644 --- a/paddle/cinn/hlir/framework/convert_to_dialect.cc +++ b/paddle/cinn/hlir/framework/convert_to_dialect.cc @@ -17,34 +17,34 @@ #include #include -#include "paddle/cinn/hlir/dialect/runtime_dialect/ir/jit_kernel_op.h" -#include "paddle/cinn/hlir/dialect/runtime_dialect/ir/runtime_dialect.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" #include "paddle/cinn/hlir/framework/program.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/program.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/program.h" namespace cinn { namespace hlir { namespace framework { -std::unique_ptr<::ir::Program> 
-std::unique_ptr<::ir::Program> ConvertToRuntimeDialect(
+std::unique_ptr<::pir::Program> ConvertToRuntimeDialect(
     const hlir::framework::Program& program) {
-  ::ir::IrContext* ctx = ::ir::IrContext::Instance();
+  ::pir::IrContext* ctx = ::pir::IrContext::Instance();
   ctx->GetOrRegisterDialect<dialect::RuntimeDialect>();
 
-  auto ir_program = std::make_unique<::ir::Program>(ctx);
+  auto ir_program = std::make_unique<::pir::Program>(ctx);
   std::string jit_op_name = dialect::JitKernelOp::name();
-  ::ir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name);
+  ::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name);
 
   auto& instrs = program.GetRunInstructions();
   for (auto& instr : instrs) {
-    std::unordered_map<std::string, ::ir::Attribute> op_attrs{
+    std::unordered_map<std::string, ::pir::Attribute> op_attrs{
         {dialect::JitKernelOp::kAttrName,
-         ::ir::PointerAttribute::get(ctx, instr.get())},
+         ::pir::PointerAttribute::get(ctx, instr.get())},
     };
 
-    ::ir::Operation* cinn_op =
-        ::ir::Operation::Create({}, op_attrs, {}, op_info);
+    ::pir::Operation* cinn_op =
+        ::pir::Operation::Create({}, op_attrs, {}, op_info);
     ir_program->block()->push_back(cinn_op);
   }
 
   return std::move(ir_program);
diff --git a/paddle/cinn/hlir/framework/convert_to_dialect.h b/paddle/cinn/hlir/framework/convert_to_dialect.h
index a88b5222b63bd..7ea0a2ace40c7 100644
--- a/paddle/cinn/hlir/framework/convert_to_dialect.h
+++ b/paddle/cinn/hlir/framework/convert_to_dialect.h
@@ -16,16 +16,16 @@
 
 #include <memory>
 
-namespace ir {
+namespace pir {
 class Program;
-}  // namespace ir
+}  // namespace pir
 
 namespace cinn {
 namespace hlir {
 namespace framework {
 class Program;
 
-std::unique_ptr<::ir::Program> ConvertToRuntimeDialect(
+std::unique_ptr<::pir::Program> ConvertToRuntimeDialect(
     const hlir::framework::Program& program);
 
 }  // namespace framework
diff --git a/paddle/cinn/hlir/framework/new_ir/group.h b/paddle/cinn/hlir/framework/new_ir/group.h
index b62c315873c70..1a67a02e58ca9 100644
--- a/paddle/cinn/hlir/framework/new_ir/group.h
+++ b/paddle/cinn/hlir/framework/new_ir/group.h
@@ -18,7 +18,7 @@
 
 #include "paddle/cinn/hlir/framework/new_ir/utils.h"
 #include "paddle/cinn/hlir/framework/op.h"
-#include "paddle/ir/core/operation.h"
+#include "paddle/pir/core/operation.h"
 
 namespace cinn {
 namespace hlir {
@@ -29,12 +29,12 @@ using framework::OpPatternKind;
 // TODO(Aurelius84): Need to be replaced with CinnGroupOp
 struct Group {
  public:
-  explicit Group(const std::vector<::ir::Operation*>& group_ops)
+  explicit Group(const std::vector<::pir::Operation*>& group_ops)
       : ops(group_ops) {
     Initialize();
   }
 
-  explicit Group(std::initializer_list<::ir::Operation*> group_ops)
+  explicit Group(std::initializer_list<::pir::Operation*> group_ops)
       : ops(group_ops) {
     Initialize();
   }
@@ -42,7 +42,7 @@ struct Group {
   int group_id;
   std::string fn_name;
   OpPatternKind op_pattern_kind;
-  std::vector<::ir::Operation*> ops;
+  std::vector<::pir::Operation*> ops;
   std::vector<std::string> input_names;
   std::vector<std::string> output_names;
diff --git a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc
index d291aba2e406e..235d545dc331f 100644
--- a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc
+++ b/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc
@@ -23,7 +23,7 @@
 #include "paddle/cinn/hlir/framework/new_ir/utils.h"
 #include "paddle/cinn/lang/placeholder.h"
 #include "paddle/cinn/utils/attribute_util.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
 #include "paddle/phi/core/ddim.h"
 
 PD_DECLARE_bool(cinn_use_cuda_vectorize);
@@ -39,7 +39,7 @@ using framework::OpPatternKind;
 using framework::StrategyFunction;
 
 namespace details {
-ir::Tensor GetTensor(const ::ir::Value& value) {
+ir::Tensor GetTensor(const ::pir::Value& value) {
   auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
   auto in_shape = phi::vectorize(type_info.dims());
   auto dtype = type_info.dtype();
@@ -49,9 +49,9 @@ ir::Tensor GetTensor(const ::ir::Value& value) {
 }
 
 std::vector<ir::Tensor> CollectInputTensor(
-    const ::ir::Operation* op,
+    const ::pir::Operation* op,
     std::vector<ir::Tensor>* func_args,
-    std::unordered_map<::ir::Value, ir::Tensor>* tensor_map) {
+    std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) {
   std::vector<ir::Tensor> tensors;
   for (auto& operand : op->operands()) {
     CHECK(operand);
@@ -72,7 +72,7 @@ std::vector<ir::Tensor> CollectInputTensor(
   return tensors;
 }
 
-void CollectOutputInfo(const ::ir::Operation* op,
+void CollectOutputInfo(const ::pir::Operation* op,
                        std::vector<Type>* out_types,
                        std::vector<std::vector<int>>* out_shapes) {
   auto op_results = op->results();
@@ -88,7 +88,7 @@ void CollectOutputInfo(const ::ir::Operation* op,
   }
 }
 
-NodeAttr CollectAttrs(const ::ir::Operation& op) {
+NodeAttr CollectAttrs(const ::pir::Operation& op) {
   NodeAttr node_attrs;
   VLOG(4) << "op.attributes():" << op.attributes().size();
   auto attrs = utils::ConvertAttributes(op.attributes());
@@ -134,18 +134,18 @@ std::vector<ir::LoweredFunc> OpLowererImpl::Lower(const GroupPtr& group,
   }
 }
 
-bool OpLowererImpl::ElementwiseScheduleDetermineFunction(::ir::Operation* op) {
+bool OpLowererImpl::ElementwiseScheduleDetermineFunction(::pir::Operation* op) {
   return true;
 }
 
-bool OpLowererImpl::ReduceScheduleDetermineFunction(::ir::Operation* op) {
+bool OpLowererImpl::ReduceScheduleDetermineFunction(::pir::Operation* op) {
   // TODO(Aurelius84): Support this.
   // auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
   // return op_pattern_dict[op] == framework::kReduction;
   return true;
 }
 
-bool OpLowererImpl::NonFusibleScheduleDetermineFunction(::ir::Operation* op) {
+bool OpLowererImpl::NonFusibleScheduleDetermineFunction(::pir::Operation* op) {
   return true;
 }
 
@@ -160,7 +160,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::LowerGroup(
     return LowerCustomCall(group);
   }
   std::vector<ir::Tensor> group_func_arg_tensors;
-  std::unordered_map<::ir::Value, ir::Tensor> tensor_map;
+  std::unordered_map<::pir::Value, ir::Tensor> tensor_map;
   bool do_op_schedule = apply_group_schedule || apply_op_schedule;
   std::vector<ir::Expr> func_bodies = LowerOps(ops,
                                                do_op_schedule,
@@ -191,8 +191,8 @@ std::vector<ir::LoweredFunc> OpLowererImpl::LowerCustomCall(
     const GroupPtr& group) {
   auto& ops = group->ops;
   CHECK_EQ(ops.size(), 1);
-  ::ir::Operation* op = ops[0];
-  std::unordered_map<::ir::Value, ir::Tensor> tensor_map;
+  ::pir::Operation* op = ops[0];
+  std::unordered_map<::pir::Value, ir::Tensor> tensor_map;
   std::vector<ir::Tensor> op_func_arg_tensors =
       details::CollectInputTensor(op, nullptr, &tensor_map);
   VLOG(4) << "inputs.size(): " << op_func_arg_tensors.size();
@@ -234,7 +234,7 @@ std::vector<ir::LoweredFunc> OpLowererImpl::LowerCustomCall(
 
 std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
     const GroupPtr& group,
-    const std::unordered_map<::ir::Value, ir::Tensor>& tensor_map,
+    const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
     bool done_op_schedule,
     ir::IRSchedule* ir_sch,
     std::vector<ir::Tensor>* group_func_arg_tensors) {
@@ -313,11 +313,11 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
 }
 
 std::vector<ir::Expr> OpLowererImpl::LowerOps(
-    const std::vector<::ir::Operation*>& ops,
+    const std::vector<::pir::Operation*>& ops,
     bool apply_op_schedule,
     ScheduleDetermineFunction schedule_determine_func,
     std::vector<ir::Tensor>* group_func_arg_tensors,
-    std::unordered_map<::ir::Value, ir::Tensor>* tensor_map) {
+    std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) {
   auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
   std::vector<ir::Expr> func_bodies;
   for (auto* op : ops) {
@@ -359,8 +359,8 @@ std::vector<ir::Expr> OpLowererImpl::LowerOps(
 
 std::vector<ir::Expr> OpLowererImpl::DoOpLower(
     std::shared_ptr<OpImpl> op_impl,
-    const ::ir::Operation* op,
-    std::unordered_map<::ir::Value, ir::Tensor>* tensor_map,
+    const ::pir::Operation* op,
+    std::unordered_map<::pir::Value, ir::Tensor>* tensor_map,
     std::vector<ir::Tensor>* op_func_arg_tensors) {
   VLOG(4) << "Do lower with Compute, op: " << op->name();
   std::vector<common::CINNValue> cinn_inputs;
diff --git a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h b/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h
index ffa6218299100..81e36d8bb7b3b 100644
--- a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h
+++ b/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h
@@ -26,7 +26,7 @@
 #include "paddle/cinn/ir/schedule/ir_schedule.h"
 #include "paddle/cinn/ir/schedule/ir_schedule_util.h"
 #include "paddle/cinn/lang/packed_func.h"
-#include "paddle/ir/core/operation.h"
+#include "paddle/pir/core/operation.h"
 
 // Fusion Op lowering, there are four kinds of lowering function:
 // Elementwise/Broadcast/Injective,Reduce,OutEWiseFusable,NonFusible.
@@ -43,7 +43,7 @@ using GroupPtr = std::shared_ptr<Group>;
 using common::Target;
 class OpLowererImpl;
 
-typedef bool (OpLowererImpl::*ScheduleDetermineFunction)(::ir::Operation*);
+typedef bool (OpLowererImpl::*ScheduleDetermineFunction)(::pir::Operation*);
 
 class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
  public:
@@ -96,7 +96,7 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
    */
   std::vector<ir::LoweredFunc> PostProcess(
       const GroupPtr& group,
-      const std::unordered_map<::ir::Value, ir::Tensor>& tensor_map,
+      const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map,
       bool done_op_schedule,
       ir::IRSchedule* ir_sch,
       std::vector<ir::Tensor>* group_func_arg_tensors);
@@ -114,11 +114,11 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
    * @return The lowered func bodies of Op set.
    */
   std::vector<ir::Expr> LowerOps(
-      const std::vector<::ir::Operation*>& ops,
+      const std::vector<::pir::Operation*>& ops,
       bool apply_op_schedule,
       ScheduleDetermineFunction schedule_determine_func,
       std::vector<ir::Tensor>* group_func_arg_tensors,
-      std::unordered_map<::ir::Value, ir::Tensor>* tensor_map);
+      std::unordered_map<::pir::Value, ir::Tensor>* tensor_map);
 
   /**
    * @brief Lower an Op to CINN IR. The Compute and Lower processes will be
@@ -131,8 +131,8 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
    */
   std::vector<ir::Expr> DoOpLower(
       std::shared_ptr<OpImpl> op_impl,
-      const ::ir::Operation* op,
-      std::unordered_map<::ir::Value, ir::Tensor>* tensor_map,
+      const ::pir::Operation* op,
+      std::unordered_map<::pir::Value, ir::Tensor>* tensor_map,
       std::vector<ir::Tensor>* op_func_arg_tensors);
 
   /**
@@ -148,9 +148,9 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
 
   // Functions used to determine which Ops to schedule at op level, define a
   // policy for each type of group.
-  inline bool ReduceScheduleDetermineFunction(::ir::Operation* op);
-  inline bool ElementwiseScheduleDetermineFunction(::ir::Operation* op);
-  inline bool NonFusibleScheduleDetermineFunction(::ir::Operation* op);
+  inline bool ReduceScheduleDetermineFunction(::pir::Operation* op);
+  inline bool ElementwiseScheduleDetermineFunction(::pir::Operation* op);
+  inline bool NonFusibleScheduleDetermineFunction(::pir::Operation* op);
 
  private:
   Target target_;
diff --git a/paddle/cinn/hlir/framework/new_ir/utils.cc b/paddle/cinn/hlir/framework/new_ir/utils.cc
index 38bfcf05776e0..b027992af8c47 100644
--- a/paddle/cinn/hlir/framework/new_ir/utils.cc
+++ b/paddle/cinn/hlir/framework/new_ir/utils.cc
@@ -20,9 +20,9 @@ namespace framework {
 namespace newir {
 
 const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
-    {"pd.full", "fill_constant"}};
+    {"pd_op.full", "fill_constant"}};
 
-std::string CompatibleInfo::OpName(const ::ir::Operation& op) {
+std::string CompatibleInfo::OpName(const ::pir::Operation& op) {
   std::string name = op.name();
   if (OP_NAMES.count(name)) {
     return OP_NAMES.at(name);
@@ -36,12 +36,12 @@ std::string CompatibleInfo::OpName(const ::ir::Operation& op) {
   return cinn_op_name;
 }
 
-std::string CompatibleInfo::ValueName(const ::ir::Value& value) {
+std::string CompatibleInfo::ValueName(const ::pir::Value& value) {
   return CompatibleInfo::kNamePrefix +
-         std::to_string(std::hash<::ir::Value>()(value));
+         std::to_string(std::hash<::pir::Value>()(value));
 }
 
-std::string CompatibleInfo::OpFuncName(const ::ir::Operation& op) {
+std::string CompatibleInfo::OpFuncName(const ::pir::Operation& op) {
   std::string op_name = OpName(op);
   std::string func_name =
       cinn::common::Context::Global().NewName("fn_" + op_name);
@@ -49,7 +49,7 @@ std::string CompatibleInfo::OpFuncName(const ::ir::Operation& op) {
 }
 
 std::string CompatibleInfo::GroupOpsName(
-    const std::vector<::ir::Operation*>& ops) {
+    const std::vector<::pir::Operation*>& ops) {
   std::string name = "fn";
   for (auto* op : ops) {
     std::string op_name = OpName(*op);
@@ -58,7 +58,7 @@ std::string CompatibleInfo::GroupOpsName(
   return name;
 }
 
-std::vector<std::string> CompatibleInfo::InputNames(const ::ir::Operation& op,
+std::vector<std::string> CompatibleInfo::InputNames(const ::pir::Operation& op,
                                                     bool allow_duplicate) {
   std::vector<std::string> names;
   std::unordered_set<std::string> repeat;
@@ -75,7 +75,7 @@ std::vector<std::string> CompatibleInfo::InputNames(const ::ir::Operation& op,
 }
 
 std::vector<std::string> CompatibleInfo::OutputNames(
-    const ::ir::Operation& op) {
+    const ::pir::Operation& op) {
   std::vector<std::string> names;
   for (int i = 0; i < op.num_results(); ++i) {
     auto value = op.result(i);
diff --git a/paddle/cinn/hlir/framework/new_ir/utils.h b/paddle/cinn/hlir/framework/new_ir/utils.h
index 4c437dd19ef8a..2a70cd9eedc17 100644
--- a/paddle/cinn/hlir/framework/new_ir/utils.h
+++ b/paddle/cinn/hlir/framework/new_ir/utils.h
@@ -16,7 +16,7 @@
 #include <string>
 #include <unordered_map>
 #include "paddle/cinn/common/context.h"
-#include "paddle/ir/core/operation.h"
+#include "paddle/pir/core/operation.h"
 
 namespace cinn {
 namespace hlir {
@@ -29,18 +29,18 @@ struct CompatibleInfo {
   // macros or attempt to unify Op name with Paddle and CINN.
   static const std::unordered_map<std::string, std::string> OP_NAMES;
 
-  static std::string OpName(const ::ir::Operation& op);
+  static std::string OpName(const ::pir::Operation& op);
 
-  static std::string ValueName(const ::ir::Value& value);
+  static std::string ValueName(const ::pir::Value& value);
 
-  static std::string OpFuncName(const ::ir::Operation& op);
+  static std::string OpFuncName(const ::pir::Operation& op);
 
-  static std::string GroupOpsName(const std::vector<::ir::Operation*>& ops);
+  static std::string GroupOpsName(const std::vector<::pir::Operation*>& ops);
 
-  static std::vector<std::string> InputNames(const ::ir::Operation& op,
+  static std::vector<std::string> InputNames(const ::pir::Operation& op,
                                              bool allow_duplicate = false);
 
-  static std::vector<std::string> OutputNames(const ::ir::Operation& op);
+  static std::vector<std::string> OutputNames(const ::pir::Operation& op);
 };
 
 }  // namespace newir
diff --git a/paddle/cinn/hlir/framework/new_ir_compiler.cc b/paddle/cinn/hlir/framework/new_ir_compiler.cc
index bcc7c0f1c2a05..9172a1d8b052f 100644
--- a/paddle/cinn/hlir/framework/new_ir_compiler.cc
+++ b/paddle/cinn/hlir/framework/new_ir_compiler.cc
@@ -17,8 +17,8 @@
 #include <unordered_map>
 #include "paddle/cinn/hlir/framework/new_ir/utils.h"
 #include "paddle/cinn/utils/attribute_util.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
-#include "paddle/ir/core/builtin_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/pir/core/builtin_type.h"
 
 namespace cinn {
 namespace hlir {
@@ -33,7 +33,7 @@ std::unique_ptr<Program> NewIRCompiler::Build() {
   std::vector<GroupPtr> groups;
   for (auto it = program_.block()->begin(); it != program_.block()->end();
        ++it) {
-    std::vector<::ir::Operation*> ops = {*it};
+    std::vector<::pir::Operation*> ops = {*it};
     groups.push_back(std::make_shared<Group>(ops));
   }
   VLOG(4) << "Groups size: " << groups.size();
@@ -123,11 +123,11 @@ std::vector<std::unique_ptr<Instruction>> NewIRCompiler::BuildInstructions(
 }
 
 std::shared_ptr<Scope> BuildScope(const Target& target,
-                                  const ::ir::Program& program) {
-  std::unordered_set<::ir::Value> visited;
+                                  const ::pir::Program& program) {
+  std::unordered_set<::pir::Value> visited;
   auto scope = std::make_shared<Scope>();
 
-  auto create_var = [&](::ir::Value value) {
+  auto create_var = [&](::pir::Value value) {
     if (visited.count(value) > 0) return;
     visited.emplace(value);
diff --git a/paddle/cinn/hlir/framework/new_ir_compiler.h b/paddle/cinn/hlir/framework/new_ir_compiler.h
index bb18da54bc4f3..62c3d97a21a41 100644
--- a/paddle/cinn/hlir/framework/new_ir_compiler.h
+++ b/paddle/cinn/hlir/framework/new_ir_compiler.h
@@ -17,7 +17,7 @@
 #include <memory>
 #include <unordered_map>
 #include "paddle/cinn/common/macros.h"
-#include "paddle/ir/core/program.h"
+#include "paddle/pir/core/program.h"
 
 #include "paddle/cinn/hlir/framework/graph_compiler.h"
 #include "paddle/cinn/hlir/framework/op_lowering.h"
@@ -30,7 +30,7 @@ namespace framework {
 // the co-existence with GraphCompiler.
 class NewIRCompiler final {
  public:
-  NewIRCompiler(const ::ir::Program& prog,
+  NewIRCompiler(const ::pir::Program& prog,
                 const Target& target,
                 const std::shared_ptr<Scope>& scope)
       : program_(prog),
@@ -45,14 +45,14 @@ class NewIRCompiler final {
  private:
   CINN_DISALLOW_COPY_AND_ASSIGN(NewIRCompiler);
 
-  std::vector<ir::LoweredFunc> GetOpFunc(const ::ir::Operation& op, int idx);
+  std::vector<ir::LoweredFunc> GetOpFunc(const ::pir::Operation& op, int idx);
 
   void ProcessFunction(const std::vector<ir::LoweredFunc>& lowered_funcs);
 
   std::vector<std::unique_ptr<Instruction>> BuildInstructions(
       const std::vector<GroupPtr>& groups);
 
-  const ::ir::Program& program_;
+  const ::pir::Program& program_;
   ir::Module::Builder m_builder_;
   std::unique_ptr<backends::Compiler> compiler_{nullptr};
   Target target_;
@@ -60,7 +60,7 @@ class NewIRCompiler final {
   std::unordered_map<std::string, std::string> func_names_;
 };
 
-std::shared_ptr<Scope> BuildScope(const Target&, const ::ir::Program&);
+std::shared_ptr<Scope> BuildScope(const Target&, const ::pir::Program&);
 
 }  // namespace framework
 }  // namespace hlir
diff --git a/paddle/cinn/ir/lowered_func.cc b/paddle/cinn/ir/lowered_func.cc
index 84e8fb3e974e7..5a897e7c334a5 100644
--- a/paddle/cinn/ir/lowered_func.cc
+++ b/paddle/cinn/ir/lowered_func.cc
@@ -27,7 +27,6 @@
 #include "paddle/cinn/ir/buffer.h"
 #include "paddle/cinn/ir/utils/ir_printer.h"
 #include "paddle/cinn/ir/utils/ir_visitor.h"
-#include "paddle/cinn/optim/tensor_write_tell.h"
 #include "paddle/cinn/runtime/intrinsic.h"
 #include "paddle/cinn/utils/string.h"
 
@@ -209,8 +208,7 @@ void _LoweredFunc_::AllocTempBuffer() {}
 void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) {
   buffer_data_cast_exprs.clear();
   // collect write.
-  optim::TensorWriteTeller write_teller;
-  write_teller.Collect(&body);
+  auto write_teller = ir::CollectTensorNeedsWrite(&body);
 
   auto tensors = CollectAllTensorReference(with_expr_gen_tensor);
   std::sort(tensors.begin(),
@@ -224,7 +222,7 @@ void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) {
     if (!tensor->buffer.defined()) continue;
 
     Type value_type = tensor->type().ElementOf();
-    bool is_const = !write_teller.IsWrite(tensor->name);
+    bool is_const = !write_teller.count(tensor->name);
     value_type.set_cpp_handle();
     value_type.set_cpp_const(is_const);
     Var variable = _Var_::Make(tensor->name, value_type);
@@ -250,8 +248,7 @@ std::vector<Expr> _LoweredFunc_::CudaAliasVarExprs() const {
   }
   // collect write.
std::vector res; - optim::TensorWriteTeller write_teller; - write_teller.Collect(&body); + auto write_teller = ir::CollectTensorNeedsWrite(&body); auto tensors = CollectAllTensorReference(); std::sort(tensors.begin(), @@ -269,7 +266,7 @@ std::vector _LoweredFunc_::CudaAliasVarExprs() const { continue; } Type value_type = tensor->type().ElementOf(); - bool is_const = !write_teller.IsWrite(tensor->name); + bool is_const = !write_teller.count(tensor->name); value_type.set_cpp_handle(); value_type.set_cpp_const(is_const); Var variable = _Var_::Make(tensor->name, value_type); diff --git a/paddle/cinn/ir/operation.cc b/paddle/cinn/ir/operation.cc index 44b1af64fe6b0..9dff3b5e0a5f9 100644 --- a/paddle/cinn/ir/operation.cc +++ b/paddle/cinn/ir/operation.cc @@ -49,10 +49,12 @@ Operation ComputeOp::Make(const std::string &name, n->reduce_axis = reduce_axis; n->tag = tag; n->attrs = attrs; - auto axis = common::GenDefaultAxis(domain.size()); - std::vector _axis; - for (auto &x : axis) _axis.push_back(x); - n->body = {handle(_axis)}; + n->axis = common::GenDefaultAxis(domain.size()); + std::vector tmp_axis; + for (auto &x : n->axis) { + tmp_axis.push_back(x); + } + n->body = {handle(tmp_axis)}; n->reduce_axis = reduce_axis; return Operation(n); } diff --git a/paddle/cinn/ir/operation.h b/paddle/cinn/ir/operation.h index 651c2a9a9dc5c..cdc5175830e38 100644 --- a/paddle/cinn/ir/operation.h +++ b/paddle/cinn/ir/operation.h @@ -105,6 +105,8 @@ struct BufferShareOp : public _Operation_ { */ struct ComputeOp : public _Operation_ { using handle_t = std::function &)>; + //! Var on each dimension + std::vector axis; //! Var on each reduction axis, if the body is a Reduction. std::vector reduce_axis; //! Shape of the output. diff --git a/paddle/cinn/ir/tensor.cc b/paddle/cinn/ir/tensor.cc index 2bfa6ee7737ef..7631141d115cd 100644 --- a/paddle/cinn/ir/tensor.cc +++ b/paddle/cinn/ir/tensor.cc @@ -16,6 +16,7 @@ #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/common/arithmatic.h" #include "paddle/cinn/common/axis.h" @@ -250,6 +251,11 @@ Expr *_Tensor_::mutable_body() { CINN_NOT_IMPLEMENTED } +ir::Tensor _Tensor_::InitReduction( + ast_gen_ius::TensorGroup *tensor_group) const { + return tensor_group->MarkReduceInit(this->name); +} + ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target) const { CHECK(contains_reduce_axis()) diff --git a/paddle/cinn/ir/tensor.h b/paddle/cinn/ir/tensor.h index 8879e35afa98d..fd8e79f73ffdd 100644 --- a/paddle/cinn/ir/tensor.h +++ b/paddle/cinn/ir/tensor.h @@ -25,6 +25,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/common/graph_utils.h" #include "paddle/cinn/ir/buffer.h" #include "paddle/cinn/ir/function_base.h" @@ -33,28 +34,13 @@ namespace cinn { -namespace ir { -class Tensor; -} // namespace ir - -namespace lang { -template -struct Placeholder; - -void InitReduceTensor(poly::StageMap stages, - const ir::Tensor& tensor, - const Target& target = common::DefaultHostTarget()); -} // namespace lang +namespace ast_gen_ius { +class TensorGroup; +} // namespace ast_gen_ius namespace ir { -namespace detail { -constexpr bool LE(int a, int b) { return a <= b; } -constexpr bool GE(int a, int b) { return a >= b; } - -} // namespace detail class _Tensor_; -class Tensor; class Tensor : public ir::IrNodeRef { public: @@ -84,8 +70,8 @@ class Tensor : public ir::IrNodeRef { return operator()(std::vector({a})); } template - inline typename std::enable_if::type - 
operator()(Args&&... args) const { + inline typename std::enable_if= 2, Expr>::type operator()( + Args&&... args) const { return operator()({std::forward(args)...}); } // @} @@ -288,11 +274,7 @@ class _Tensor_ : public ExprNode<_Tensor_> { poly::StageMap stages, const Target& target = common::DefaultHostTarget()) const; - private: - //! Initialize the axis field after the shape field is assigned. - void InitAxis() const; - - isl::set GenerateIslDomain() const; + ir::Tensor InitReduction(ast_gen_ius::TensorGroup* tensor_group) const; /** * Create the initialization tensor. @@ -304,15 +286,17 @@ class _Tensor_ : public ExprNode<_Tensor_> { poly::StageMap stages, const Target& target = common::DefaultHostTarget()) const; + private: + //! Initialize the axis field after the shape field is assigned. + void InitAxis() const; + + isl::set GenerateIslDomain() const; + //! The names of the tensors depend the same buffer and should schedule before //! this. std::set buffer_depended_tensor_names_; friend Shared CreateStage(Tensor tensor); - - friend void lang::InitReduceTensor(poly::StageMap stages, - const ir::Tensor& tensor, - const Target& target); }; Shared CreateStage(Tensor tensor); diff --git a/paddle/cinn/ir/utils/ir_nodes_collector.cc b/paddle/cinn/ir/utils/ir_nodes_collector.cc index e99da88a1dd35..d44c3701b5ac2 100644 --- a/paddle/cinn/ir/utils/ir_nodes_collector.cc +++ b/paddle/cinn/ir/utils/ir_nodes_collector.cc @@ -207,5 +207,116 @@ std::set CollectReferencedTensors( return ts0; } +std::vector CollectUndefinedVars(const Expr* e) { + struct Mutator : public ir::IRMutator { + using ir::IRMutator::Visit; + std::vector undefined_vars; + std::set defined_vars; + std::set used_vars; + + void CollectVarDef(const std::string& var) { + CHECK(!defined_vars.count(var)) + << "var " << var << " has been defined, please check"; + CHECK(!used_vars.count(var)) + << "var " << var << " is wrongly used before definition"; + defined_vars.insert(var); + } + + void ClearVar(const std::string& var) { + defined_vars.erase(var); + used_vars.erase(var); + } + + void CollectVarUse(const std::string& var) { + used_vars.insert(var); + if (defined_vars.count(var) == 0) { + undefined_vars.push_back(var); + } + } + + void Visit(const ir::Let* op, const Expr* expr) override { + Expr symbol = op->symbol; + auto var = symbol.as_var_ref(); + CHECK(var.defined()); + CollectVarDef(var->name); + auto* node = expr->As(); + Visit(&node->body, &node->body); + } + + void Visit(const ir::For* op, const Expr* expr) override { + CollectVarDef(op->loop_var->name); + auto* node = expr->As(); + Visit(&node->min, &node->min); + Visit(&node->extent, &node->extent); + Visit(&node->body, &node->body); + ClearVar(op->loop_var->name); + } + + void Visit(const ir::Load* op, const Expr* expr) override { + auto tensor = op->tensor.as_tensor_ref(); + CollectVarUse(tensor->name); + auto* node = expr->As(); + for (auto& idx : node->indices) Visit(&idx, &idx); + } + + void Visit(const ir::Store* op, const Expr* expr) override { + auto tensor = op->tensor.as_tensor_ref(); + CollectVarUse(tensor->name); + auto* node = expr->As(); + for (auto& idx : node->indices) Visit(&idx, &idx); + Visit(&node->value, &node->value); + } + + void Visit(const ir::_Var_* op, const Expr* expr) override { + CollectVarUse(op->name); + auto* node = expr->As(); + if (node->lower_bound.defined()) { + Visit(&node->lower_bound, &node->lower_bound); + } + if (node->upper_bound.defined()) { + Visit(&node->upper_bound, &node->upper_bound); + } + } + + void Visit(const 
ir::Reduce* op, const Expr* expr) override { + for (auto& axis : op->reduce_axis) { + CollectVarDef(axis->name); + } + auto* node = expr->As(); + if (node->init.defined()) Visit(&node->init, &node->init); + Visit(&node->body, &node->body); + } + }; + + Mutator mutator; + mutator.Visit(e, e); + return mutator.undefined_vars; +} + +std::set CollectTensorNeedsWrite(const Expr* e) { + std::set tensor_written; + IrNodesCollector::handler_t handler = [&](const Expr* x) { + if (x->As()) { + tensor_written.insert( + x->As()->tensor.As()->name); + } + if (x->As()) { + tensor_written.insert(x->As()->name); + } + }; + IrNodesCollector::teller_t teller = [](const Expr* x) { + if (x->As() && x->As()->tensor.As()) { + return true; + } + if (x->As() && x->As()->is_call_node()) { + return true; + } + return false; + }; + IrNodesCollector collector(std::move(teller), std::move(handler), false); + collector.Visit(e); + return tensor_written; +} + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/utils/ir_nodes_collector.h b/paddle/cinn/ir/utils/ir_nodes_collector.h old mode 100755 new mode 100644 index 75ed3fa9e64f4..0f8a390e1ade7 --- a/paddle/cinn/ir/utils/ir_nodes_collector.h +++ b/paddle/cinn/ir/utils/ir_nodes_collector.h @@ -65,5 +65,24 @@ std::map CollectTensorMap( return true; }); + +/** + * Collect undefined vars in the scope. + * + * e.g. + * + * The expression: + * for i + * for j + * a[i, j] = b[i, j] + * + * here a, b are vars without definition + */ +std::vector CollectUndefinedVars(const Expr* e); + +/** + * Collect the tensor nodes that will be written by Store or Call nodes. + */ +std::set CollectTensorNeedsWrite(const Expr* e); + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/lang/CMakeLists.txt b/paddle/cinn/lang/CMakeLists.txt index f4ef9e6d7b103..62d91e8103c7a 100644 --- a/paddle/cinn/lang/CMakeLists.txt +++ b/paddle/cinn/lang/CMakeLists.txt @@ -7,6 +7,8 @@ gather_srcs( compute.cc placeholder.cc lower.cc + lower_impl.cc + lower_tensor_group.cc builtin.cc lower_impl.cc packed_func.cc) diff --git a/paddle/cinn/lang/lower.cc b/paddle/cinn/lang/lower.cc old mode 100755 new mode 100644 index 1661f65975c8f..667c0646c43cd --- a/paddle/cinn/lang/lower.cc +++ b/paddle/cinn/lang/lower.cc @@ -24,12 +24,14 @@ #include "paddle/cinn/ir/buffer.h" #include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/lang/lower_impl.h" +#include "paddle/cinn/lang/lower_tensor_group.h" #include "paddle/cinn/optim/optimize.h" #include "paddle/cinn/utils/string.h" namespace cinn { namespace lang { +using ast_gen_ius::TensorGroup; using ir::Tensor; using poly::Stage; @@ -84,6 +86,49 @@ std::vector GetArgs( return res; } +//! Collect the temporary tensors from a computational graph. +std::vector GetTempBuffers(const std::vector& tensor_args, + const TensorGroup& tensor_group, + Expr body) { + std::unordered_set tensor_arg_names; + std::unordered_set buffer_arg_names; + for (auto& tensor : tensor_args) { + tensor_arg_names.insert(tensor->name); + if (tensor->buffer.defined()) { + buffer_arg_names.insert(tensor->buffer->name); + } + } + std::map + name_to_buffer; // used to avoid duplication.
+ + auto all_temp_tensors = + ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) { + return x->as_tensor() && x->as_tensor()->buffer.defined() && + (!tensor_group.Contain(x->as_tensor()->name) && + ((!buffer_arg_names.count(x->as_tensor()->buffer->name) && + !tensor_arg_names.count(x->as_tensor()->name)) || + utils::Endswith(x->as_tensor()->buffer->name, "temp_buffer"))); + }); + for (auto& e : all_temp_tensors) { + auto buffer_name = e.as_tensor()->buffer->name; + if (!name_to_buffer.count(buffer_name)) { + name_to_buffer[buffer_name] = e.as_tensor()->buffer; + } else { + // Behavior kept from the old implementation: when two tensors share a + // buffer name, keep the smaller buffer. + if (e.as_tensor()->buffer->numel() < + name_to_buffer[buffer_name]->numel()) { + name_to_buffer[buffer_name] = e.as_tensor()->buffer; + } + } + } + + std::vector temp_buffers; + for (auto& i : name_to_buffer) { + temp_buffers.push_back(i.second); + } + return temp_buffers; +} + //! Collect the temporary tensors from a computational graph. std::vector GetTempBuffers(const std::vector& tensor_args, const poly::StageMap& stage_map, @@ -198,6 +243,25 @@ std::set CollectTempTensorsFromCtrlDepends( return res; } +void InitReduceTensor(TensorGroup* tensor_group, + const Tensor& tensor, + const Target& target) { + if (tensor->is_reduce_tensor()) { + tensor_group->MarkReduceInit(tensor->name); + } + auto uninited_reduce_tensors = + ir::CollectIRNodes(tensor->body(), [&](const Expr* x) { + return x && x->defined() && x->as_tensor() && + x->as_tensor()->is_reduce_tensor() && + !tensor_group->HasMarkedReduceInit(x->as_tensor()->name); + }); + for (auto& t : uninited_reduce_tensors) { + std::string reduce_name = t.as_tensor()->name; + VLOG(3) << "Init reduce tensor: " << reduce_name; + tensor_group->MarkReduceInit(reduce_name); + } +} + void InitReduceTensor(StageMap stages, const Tensor& tensor, const Target& target) { @@ -216,6 +280,63 @@ void InitReduceTensor(StageMap stages, } } +std::set CollectTempTensorsFromCtrlDepends( + ast_gen_ius::TensorGroup* tensor_group, + const std::vector& tensor_args) { + std::set res; + for (const ir::Tensor& a : tensor_group->GetAllTensors()) { + for (const ir::Tensor& t : tensor_group->GetCrtlDepTensors(a->name)) { + res.emplace(t); + } + } + for (const ir::Tensor& t : tensor_args) { + if (res.count(t)) { + res.erase(t); + } + } + return res; +} + +ir::LoweredFunc LowerToAst(const std::string& name, + const std::vector& tensor_args, + ast_gen_ius::TensorGroup* tensor_group, + const Target& target) { + // Init the reduce tensors first before any process.
+ for (auto& t : tensor_args) { + InitReduceTensor(tensor_group, t, target); + } + // Merge the ctrl_deps with the given temp_tensors and get the new temp_tensors + std::set ctrl_deps = + CollectTempTensorsFromCtrlDepends(tensor_group, tensor_args); + std::vector group_vec = {tensor_group}; + auto lower_instance = detail::LowerTensorGroup( + name, + tensor_args, + {}, + group_vec, + std::vector(ctrl_deps.begin(), ctrl_deps.end()), + target); + std::vector result = lower_instance(); + for (auto& res : result) { + if (target == common::DefaultNVGPUTarget()) { + res->device_api = ir::DeviceAPI::GPU; + } + } + return result[0]; +} + +std::vector LowerToAstVec( + const std::string& name, + const std::vector& tensor_args, + std::vector tensor_groups, + const Target& target) { + std::vector ret; + for (ast_gen_ius::TensorGroup* tg : tensor_groups) { + ret.push_back(LowerToAst(name, tensor_args, tg, target)); + } + return ret; +} + ir::LoweredFunc Lower(const std::string& name, StageMap stages, const std::vector& tensor_args, diff --git a/paddle/cinn/lang/lower.h b/paddle/cinn/lang/lower.h index af8a186583a69..c80d4bc769cdf 100644 --- a/paddle/cinn/lang/lower.h +++ b/paddle/cinn/lang/lower.h @@ -20,6 +20,7 @@ #include #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/lowered_func.h" #include "paddle/cinn/ir/module.h" @@ -73,6 +74,22 @@ std::vector LowerVec( const Target &target = common::DefaultHostTarget(), bool support_ir_schedule = false); +ir::LoweredFunc LowerToAst(const std::string &name, + const std::vector &tensor_args, + ast_gen_ius::TensorGroup *tensor_group, + const Target &target = common::DefaultHostTarget()); + +std::vector LowerToAstVec( + const std::string &name, + const std::vector &tensor_args, + std::vector tensor_groups, + const Target &target = common::DefaultHostTarget()); + +std::vector GetTempBuffers( + const std::vector &tensor_args, + const ast_gen_ius::TensorGroup &tensor_group, + Expr body); + std::vector GetArgs( const Expr &func_body, const std::vector &input_output_nodes); diff --git a/paddle/cinn/lang/lower_impl.cc b/paddle/cinn/lang/lower_impl.cc index f313d52938a93..629b405dcd2f0 100644 --- a/paddle/cinn/lang/lower_impl.cc +++ b/paddle/cinn/lang/lower_impl.cc @@ -25,7 +25,7 @@ #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/utils/ir_printer.h" -#include "paddle/cinn/optim/remove_nested_block.h" +#include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/optim/transform_polyfor_to_for.h" #include "paddle/cinn/poly/stage.h" @@ -342,8 +342,7 @@ std::vector LowerImpl::GenerateFunctionArgumentList( CheckArgsUnique(); std::vector args; - optim::TensorWriteTeller teller; - teller.Collect(&fn_body); + auto teller = ir::CollectTensorNeedsWrite(&fn_body); std::set arg_names; @@ -358,7 +357,7 @@ std::vector LowerImpl::GenerateFunctionArgumentList( for (auto& tensor : tensor_args_) { auto* tensor_node = tensor.As(); - bool is_output = teller.IsWrite(tensor->name); + bool is_output = teller.count(tensor->name); VLOG(1) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; @@ -396,8 +395,7 @@ std::vector LowerImpl::GenFuncArgForSplitKernel( std::vector in_args; std::vector out_args; - optim::TensorWriteTeller teller; - teller.Collect(&func_iterator); + auto teller = ir::CollectTensorNeedsWrite(&func_iterator); std::set arg_names; std::set all_tensor_names; @@ -448,7 +446,7 @@ std::vector
LowerImpl::GenFuncArgForSplitKernel( VLOG(3) << "In tensor_args_, it has : " << tensor->name; if (temp_tensor_names.count(tensor->name) > 0) continue; if (all_tensor_names.count(tensor->name) == 0) continue; - bool is_output = teller.IsWrite(tensor->name); + bool is_output = teller.count(tensor->name); VLOG(3) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; @@ -485,7 +483,7 @@ std::vector LowerImpl::GenFuncArgForSplitKernel( VLOG(3) << "Tensor " << tensor->name; if (tensor->buffer.defined() && !arg_names.count(tensor->buffer->name)) { bool is_output = - teller.IsWrite(tensor->name) && teller.IsWrite(tensor->name); + teller.count(tensor->name) && teller.count(tensor->name); if (is_output) out_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); } @@ -655,7 +653,7 @@ std::vector LowerImpl::operator()() { if (support_ir_schedule_) { optim::TransformPolyForToFor(&func->body); - optim::RemoveNestedBlock(&func->body); + optim::SimplifyBlocks(&func->body); func->body = ir::Block::Make({func->body}); result.push_back(ir::LoweredFunc(func.get())); num_func++; diff --git a/paddle/cinn/lang/lower_impl.h b/paddle/cinn/lang/lower_impl.h index 3e52279b19566..c5bfdfb1fb74d 100644 --- a/paddle/cinn/lang/lower_impl.h +++ b/paddle/cinn/lang/lower_impl.h @@ -27,14 +27,13 @@ #include "paddle/cinn/common/graph_utils.h" #include "paddle/cinn/ir/buffer.h" +#include "paddle/cinn/ir/utils/ir_mutator.h" #include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/optim/buffer_assign.h" #include "paddle/cinn/optim/compute_inline_expand.h" #include "paddle/cinn/optim/fold_cinn_call_arguments.h" #include "paddle/cinn/optim/optimize.h" -#include "paddle/cinn/optim/remove_nested_block.h" #include "paddle/cinn/optim/replace_call_with_expr.h" -#include "paddle/cinn/optim/tensor_write_tell.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" #include "paddle/cinn/optim/transform_polyfor_to_for.h" #include "paddle/cinn/poly/ast_gen.h" diff --git a/paddle/cinn/lang/lower_tensor_group.cc b/paddle/cinn/lang/lower_tensor_group.cc new file mode 100644 index 0000000000000..6fb8e72f43c68 --- /dev/null +++ b/paddle/cinn/lang/lower_tensor_group.cc @@ -0,0 +1,215 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/lang/lower_tensor_group.h" + +#include +#include +#include +#include + +#include "paddle/cinn/ast_gen_ius/ast_gen.h" +#include "paddle/cinn/ast_gen_ius/tensor_group.h" +#include "paddle/cinn/common/common.h" +#include "paddle/cinn/common/context.h" +#include "paddle/cinn/common/ir_util.h" +#include "paddle/cinn/ir/ir_base.h" +#include "paddle/cinn/ir/tensor.h" +#include "paddle/cinn/ir/utils/ir_mutator.h" +#include "paddle/cinn/ir/utils/ir_printer.h" +#include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/cinn/optim/replace_var_with_expr.h" +#include "paddle/cinn/optim/transform_polyfor_to_for.h" +#include "paddle/cinn/poly/stage.h" + +namespace cinn { +namespace lang { +namespace detail { + +LowerTensorGroup::LowerTensorGroup( + const std::string& fn_name, + const std::vector& tensor_args, + const std::vector& scalar_args, + const std::vector& tensor_groups, + const std::vector& temp_tensor_args, + const Target& target) + : fn_name_(fn_name), + tensor_args_(tensor_args), + scalar_args_(scalar_args), + tensor_groups_(tensor_groups), + temp_tensor_args_(temp_tensor_args), + target_(target) {} + +std::vector LowerTensorGroup::operator()() { + std::vector result; + int num_func = 0; + for (ast_gen_ius::TensorGroup* tensor_group : tensor_groups_) { + // 1. Generate function body + ir::Expr func_body = GenerateFunctionBody(tensor_group); + // 2. Assign buffer to tensors + auto tensor_map = tensor_group->AllocateBuffers(); + // Copy the tensor (with buffer assigned) back to the function's args. + for (auto& arg : tensor_args_) { + if (arg->is_placeholder_node() || arg->buffer.defined()) { + continue; + } + if (arg->body().As() && arg->body().type().is_void()) { + continue; // extern call + } + + if (tensor_map.find(arg->name) == tensor_map.end()) { + LOG(INFO) << "Didn't find arg tensor " << arg->name + << " in tensor_map.\n" + << "The function is " << fn_name_ + << "\nAnd all the arg tensors are:\n"; + for (auto& i : tensor_args_) { + LOG(INFO) << i->name; + } + LOG(FATAL) << "Fatal Error!"; + } + Reference(&arg)->buffer = tensor_map.at(arg->name)->buffer; + } + + // 3. Collect temp tensor buffers + std::set temp_tensor_names; + for (auto& t : temp_tensor_args_) { + temp_tensor_names.insert(t->name); + } + + // Some store tensors are also temp tensors. + auto store_exprs = ir::CollectIRNodes( + func_body, [](const Expr* x) { return x->As(); }); + for (auto& expr : store_exprs) { + auto* store_node = expr.As(); + CHECK(store_node); + auto* tensor = store_node->tensor.As(); + CHECK(tensor); + VLOG(3) << "In store_exprs, its name is : " << tensor->name; + CHECK(tensor->buffer.defined()); + if (tensor->buffer->memory_type != ir::MemoryType::Heap) { + temp_tensor_names.insert(store_node->tensor.as_tensor_ref()->name); + } + } + + std::vector temp_buffers; + std::unordered_set buffer_name_set; + for (const std::string& name : temp_tensor_names) { + if (!tensor_map.count(name)) { + continue; + } + ir::Tensor& t = tensor_map[name]; + if (t->buffer.defined() && !buffer_name_set.count(t->buffer->name)) { + temp_buffers.push_back(t->buffer); + buffer_name_set.insert(t->buffer->name); + } + } + + // 4. Handle function args + std::vector func_args = + GenerateFunctionArgumentList(func_body); + + // 5. Actual function make + std::string actual_fn_name = fn_name_; + if (num_func > 0) { + actual_fn_name += "_" + std::to_string(num_func); + VLOG(3) << "Making func :" << actual_fn_name; + } + for (auto& i : func_args) { + VLOG(3) << "func_args is : " << i.name(); + } + for (auto& i : temp_buffers) { + VLOG(3) << "temp_buffers is : " << i->name; + } + ir::LoweredFunc func = ir::_LoweredFunc_::Make( + actual_fn_name, func_args, func_body, temp_buffers); + + // 6. Final clean up + optim::SimplifyBlocks(&func->body); + func->body = ir::Block::Make({func->body}); + result.push_back(ir::LoweredFunc(func.get())); + num_func++; + } + return result; +} + +std::vector LowerTensorGroup::GenerateFunctionArgumentList( + Expr fn_body) { + std::vector args; + auto teller = ir::CollectTensorNeedsWrite(&fn_body); + + std::set arg_names; + + for (auto& scalar : scalar_args_) { + CHECK(!arg_names.count(scalar->name)); + auto* scalar_node = scalar.As(); + CHECK(scalar_node->type().valid()); + arg_names.insert(scalar->name); + + args.emplace_back(scalar, ir::Argument::IO::kInput); + } + + for (auto& tensor : tensor_args_) { + auto* tensor_node = tensor.As(); + bool is_output = teller.count(tensor->name); + VLOG(6) << "tensor argument " << tensor->name << ", buffer " + << tensor->buffer->name << ", is output: " << is_output; + + // avoid duplicate + if (!tensor_node->buffer.defined()) { + continue; + } + // If an argument is already marked as kInput, mark it as kOutput and move + // it to the back. + if (arg_names.count(tensor_node->buffer->name)) { + auto it = + std::find_if(args.begin(), args.end(), [&](const ir::Argument& x) { + return x.name() == tensor_node->buffer->name; + }); + CHECK(it != args.end()); + if (it->is_input()) { + args.erase(it); + } else if (it->is_output()) { + continue; + } + } + + arg_names.insert(tensor_node->buffer->name); + + auto io = is_output ? ir::Argument::IO::kOutput : ir::Argument::IO::kInput; + VLOG(6) << "Collect " << (is_output ? "W" : "R") << " argument " + << tensor->buffer->name; + args.emplace_back(tensor_node->buffer, io); + } + + return args; +} + +ir::Expr LowerTensorGroup::GenerateFunctionBody( + ast_gen_ius::TensorGroup* tensor_group) { + std::vector ordered_tensors = + tensor_group->GetGenFuncTopoOrder(tensor_args_); + std::vector bodies; + for (const ir::Tensor& tensor : ordered_tensors) { + bodies.emplace_back(ast_gen_ius::AstGen::Build(tensor)); + } + if (bodies.size() == 1) { + return bodies[0]; + } + + return ir::Block::Make(bodies); +} + +} // namespace detail +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/lower_tensor_group.h b/paddle/cinn/lang/lower_tensor_group.h new file mode 100644 index 0000000000000..ce7f1f7c7cdc9 --- /dev/null +++ b/paddle/cinn/lang/lower_tensor_group.h @@ -0,0 +1,74 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/cinn/ast_gen_ius/tensor_group.h" +#include "paddle/cinn/common/graph_utils.h" +#include "paddle/cinn/ir/buffer.h" +#include "paddle/cinn/ir/utils/ir_printer.h" +#include "paddle/cinn/optim/buffer_assign.h" +#include "paddle/cinn/optim/compute_inline_expand.h" +#include "paddle/cinn/optim/fold_cinn_call_arguments.h" +#include "paddle/cinn/optim/optimize.h" +#include "paddle/cinn/optim/replace_call_with_expr.h" +#include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/cinn/optim/transform_polyfor_to_for.h" +#include "paddle/cinn/poly/ast_gen.h" + +namespace cinn { +namespace lang { +namespace detail { + +class LowerTensorGroup { + public: + LowerTensorGroup(const std::string& fn_name, + const std::vector& tensor_args, + const std::vector& scalar_args, + const std::vector& tensor_groups, + const std::vector& temp_tensor_args = {}, + const Target& target = common::DefaultHostTarget()); + + std::vector operator()(); + + ir::Expr GenerateFunctionBody(ast_gen_ius::TensorGroup* tensor_group); + + std::vector GenerateFunctionArgumentList(ir::Expr fn_body); + + private: + const std::string& fn_name_; + const std::vector& tensor_args_; + const std::vector& scalar_args_; + std::vector temp_tensor_args_; + std::vector tensor_groups_; + Target target_; + + //! CUDA axis info for this function. + std::vector cuda_axis_info_; +}; + +} // namespace detail +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/lower_test.cc b/paddle/cinn/lang/lower_test.cc index 14f81090e30cb..431d73d075be6 100755 --- a/paddle/cinn/lang/lower_test.cc +++ b/paddle/cinn/lang/lower_test.cc @@ -18,6 +18,7 @@ #include +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/cinn.h" #include "paddle/cinn/lang/buffer.h" #include "paddle/cinn/lang/compute.h" @@ -27,6 +28,10 @@ namespace cinn { namespace lang { +#define TEST_SOUTPUT(x, out) \ + LOG(INFO) << "\n" << x << std::endl; \ + EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out)); + TEST(lower, basic) { auto M = Expr(100); auto N = Expr(15); @@ -42,10 +47,6 @@ TEST(lower, basic) { LOG(INFO) << "lower_size " << lower_funcs; -#define TEST_SOUTPUT(x, out) \ - std::cout << "\n" << x << std::endl; \ - EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out)); - auto out = R"ROC( { serial for (i, 0, 100) @@ -77,7 +78,7 @@ TEST(lower, more_complex) { auto lower_funcs = Lower("cal_C", stages, {A, B, C}); - std::cout << "func:\n" << Expr(lower_funcs->self()) << std::endl; + LOG(INFO) << "func:\n" << Expr(lower_funcs->self()) << std::endl; } //! To support training, the dynamic shape support is vital. 
We test the @@ -157,5 +158,34 @@ TEST(lower, temp_buffer_collects) { } } +TEST(lower_to_ast, basic) { + auto M = Expr(100); + auto N = Expr(15); + + Placeholder A("A", {Expr(M), Expr(N)}); + + ir::Tensor B = Compute( + {M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + 1.f; }, "B"); + + ast_gen_ius::TensorGroup tensor_group({B}); + + auto lower_funcs = LowerToAst("cal_B", {A, B}, &tensor_group); + + LOG(INFO) << "lower_func " << lower_funcs; + + auto out = R"ROC( +{ + serial for (i, 0, 100) + { + serial for (j, 0, 15) + { + B[i, j] = (A[i, j] + 1.00000000f) + } + } +} +)ROC"; + TEST_SOUTPUT(lower_funcs->body, out); +} + } // namespace lang } // namespace cinn diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt index 99ae9cf3bd3d6..1b4a55479ef0b 100755 --- a/paddle/cinn/optim/CMakeLists.txt +++ b/paddle/cinn/optim/CMakeLists.txt @@ -3,11 +3,9 @@ core_gather_headers() gather_srcs( cinnapi_src SRCS - remove_nested_block.cc replace_call_with_expr.cc ir_replace.cc replace_var_with_expr.cc - tensor_write_tell.cc ir_simplify.cc optimize.cc vectorize_loops.cc @@ -25,7 +23,6 @@ gather_srcs( replace_const_param_to_integer.cc lower_intrin.cc cast_bool_to_int8.cc - collect_undefined_vars.cc var_mod_simplify.cc remove_schedule_block.cc) @@ -33,8 +30,6 @@ if(WITH_CUDA) gather_srcs(cinnapi_src SRCS transform_gpu_forloop.cc) endif() -cinn_cc_test(test_remove_nested_block SRCS remove_nested_block_test.cc DEPS - cinncore) cinn_cc_test(test_ir_simplify SRCS ir_simplify_test.cc DEPS cinncore) cinn_cc_test(test_replace_call_with_expr SRCS replace_call_with_expr_test.cc DEPS cinncore) diff --git a/paddle/cinn/optim/collect_undefined_vars.cc b/paddle/cinn/optim/collect_undefined_vars.cc deleted file mode 100644 index 2f925d1333f39..0000000000000 --- a/paddle/cinn/optim/collect_undefined_vars.cc +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/cinn/optim/collect_undefined_vars.h" - -#include - -#include "paddle/cinn/ir/utils/ir_mutator.h" - -namespace cinn::optim { - -namespace { -struct Mutator : public ir::IRMutator<> { - using ir::IRMutator<>::Visit; - std::vector undefined_vars; - std::set defined_vars; - std::set used_vars; - - void CollectVarDef(const std::string& var) { - CHECK(!defined_vars.count(var)) - << "var " << var << " has been defined, please check"; - CHECK(!used_vars.count(var)) - << "var " << var << " is wrongly used before definition"; - defined_vars.insert(var); - } - - void ClearVar(const std::string& var) { - defined_vars.erase(var); - used_vars.erase(var); - } - - void CollectVarUse(const std::string& var) { - used_vars.insert(var); - if (defined_vars.count(var) == 0) { - undefined_vars.push_back(var); - } - } - - void Visit(const ir::Let* op, Expr* expr) final { - Expr symbol = op->symbol; - auto var = symbol.as_var_ref(); - CHECK(var.defined()); - CollectVarDef(var->name); - auto* node = expr->As(); - Visit(&node->body, &node->body); - } - - void Visit(const ir::For* op, Expr* expr) final { - CollectVarDef(op->loop_var->name); - auto* node = expr->As(); - Visit(&node->min, &node->min); - Visit(&node->extent, &node->extent); - Visit(&node->body, &node->body); - ClearVar(op->loop_var->name); - } - - void Visit(const ir::Load* op, Expr* expr) final { - auto tensor = op->tensor.as_tensor_ref(); - CollectVarUse(tensor->name); - auto* node = expr->As(); - for (auto& idx : node->indices) Visit(&idx, &idx); - } - - void Visit(const ir::Store* op, Expr* expr) final { - auto tensor = op->tensor.as_tensor_ref(); - CollectVarUse(tensor->name); - auto* node = expr->As(); - for (auto& idx : node->indices) Visit(&idx, &idx); - Visit(&node->value, &node->value); - } - - void Visit(const ir::_Var_* op, Expr* expr) final { - CollectVarUse(op->name); - auto* node = expr->As(); - if (node->lower_bound.defined()) { - Visit(&node->lower_bound, &node->lower_bound); - } - if (node->upper_bound.defined()) { - Visit(&node->upper_bound, &node->upper_bound); - } - } - - void Visit(const ir::Reduce* op, Expr* expr) final { - for (auto& axis : op->reduce_axis) { - CollectVarDef(axis->name); - } - auto* node = expr->As(); - if (node->init.defined()) Visit(&node->init, &node->init); - Visit(&node->body, &node->body); - } -}; -} // namespace - -std::vector CollectUndefinedVars(Expr* e) { - Mutator mutator; - mutator.Visit(e, e); - return mutator.undefined_vars; -} - -} // namespace cinn::optim diff --git a/paddle/cinn/optim/collect_undefined_vars.h b/paddle/cinn/optim/collect_undefined_vars.h deleted file mode 100644 index b83620fcc1cb0..0000000000000 --- a/paddle/cinn/optim/collect_undefined_vars.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include - -#include "paddle/cinn/ir/ir.h" -namespace cinn::optim { - -/** - * Collect undefined vars in the scope. - * - * e.g. 
- * - * The expression: - * for i - * for j - * a[i, j] = b[i, j] - * - * here a, b are vars without definition - */ -std::vector CollectUndefinedVars(Expr* e); - -} // namespace cinn::optim diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc index bfed498da521d..6cf3fcf4b7be8 100644 --- a/paddle/cinn/optim/ir_simplify.cc +++ b/paddle/cinn/optim/ir_simplify.cc @@ -305,6 +305,50 @@ struct SimplifyBlocksMutator : public ir::IRMutator<> { expr->As()->stmts = stmts; } } + + void Visit(const IfThenElse* op, Expr* expr) override { + auto* node = expr->As(); + Visit(&node->condition, &node->condition); + if (node->true_case.As() && + (node->true_case.As()->stmts.size() == 1)) { + node->true_case = node->true_case.As()->stmts[0]; + } + Visit(&node->true_case, &node->true_case); + if (node->false_case.defined()) { + if (node->false_case.As() && + (node->false_case.As()->stmts.size() == 1)) { + node->false_case = node->false_case.As()->stmts[0]; + } + Visit(&node->false_case, &node->false_case); + } + } + + void Visit(const ScheduleBlock* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + for (auto& var : node->iter_vars) { + if (var->lower_bound.defined()) { + Visit(&var->lower_bound, &var->lower_bound); + } + if (var->upper_bound.defined()) { + Visit(&var->upper_bound, &var->upper_bound); + } + } + for (auto& buffer_region : node->read_buffers) { + Visit(&buffer_region, &buffer_region); + } + for (auto& buffer_region : node->write_buffers) { + Visit(&buffer_region, &buffer_region); + } + + if (node->body.As()) { + if (node->body.As()->stmts.size() == 1) { + node->body = node->body.As()->stmts[0]; + } + } + + Visit(&(node->body), &(node->body)); + } }; struct SimplifyForLoopsMutator : public ir::IRMutator<> { diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index b1e73e3c58a9b..3764e1bd616e2 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -27,7 +27,6 @@ #include "paddle/cinn/optim/lower_function_call_bind_vars.h" #include "paddle/cinn/optim/lower_intrin.h" #include "paddle/cinn/optim/map_extern_call.h" -#include "paddle/cinn/optim/remove_nested_block.h" #include "paddle/cinn/optim/remove_schedule_block.h" #include "paddle/cinn/optim/replace_const_param_to_integer.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" @@ -65,8 +64,8 @@ Expr Optimize(Expr e, CudaSyncThreadsDropIfThenElse(&copied); #endif - RemoveNestedBlock(&copied); - VLOG(4) << "After Optimize RemoveNestedBlock:" << copied; + SimplifyBlocks(&copied); + VLOG(4) << "After SimplifyBlocks:" << copied; MapExternCall(&copied, target); VLOG(10) << "After Optimize MapExternCall:" << copied; diff --git a/paddle/cinn/optim/remove_nested_block.cc b/paddle/cinn/optim/remove_nested_block.cc deleted file mode 100644 index 06050ec5b123c..0000000000000 --- a/paddle/cinn/optim/remove_nested_block.cc +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/optim/remove_nested_block.h" - -#include "paddle/cinn/ir/utils/ir_mutator.h" -#include "paddle/cinn/ir/utils/ir_printer.h" - -namespace cinn { -namespace optim { - -Expr GetExprInsideBlock(Expr op) { - Expr node = op; - while (node.As()) { - auto& stmts = node.As()->stmts; - if (stmts.size() == 1) { - node = stmts.front(); - } else { - break; - } - } - return node; -} - -// This will remove the nested blocks, but it will also remove the block outside -// the forloop's body. -struct NestedBlockSimplifer : public ir::IRMutator { - void operator()(ir::Expr* expr) { Visit(expr); } - - private: - void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::Block* expr, Expr* op) override { - auto* node = op->As(); - if (node->stmts.size() == 1) { - *op = GetExprInsideBlock(*op); - IRMutator::Visit(op, op); - } else { - IRMutator::Visit(expr, op); - } - } -}; - -struct NestedBlockRemover : public ir::IRMutator { - void operator()(ir::Expr* expr) { Visit(expr); } - - private: - void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::Block* expr, Expr* op) override { - auto* node = op->As(); - - std::vector new_exprs; - - bool detect_nested = false; - for (auto it = node->stmts.begin(); it != node->stmts.end(); it++) { - auto* block = it->As(); - if (block) { - detect_nested = true; - new_exprs.insert( - std::end(new_exprs), block->stmts.begin(), block->stmts.end()); - } else { - new_exprs.push_back(*it); - } - } - - node->stmts = new_exprs; - - IRMutator::Visit(expr, op); - } -}; - -// add block outside forloop's body. -struct AddBlockToForloop : public ir::IRMutator<> { - void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::For* expr, Expr* op) override { - auto* node = op->As(); - if (!node->body.As()) { - node->body = ir::Block::Make({node->body}); - } - - ir::IRMutator<>::Visit(expr, op); - } - - void Visit(const ir::PolyFor* expr, Expr* op) override { - auto* node = op->As(); - if (!node->body.As()) { - node->body = ir::Block::Make({node->body}); - } - - ir::IRMutator<>::Visit(expr, op); - } - - void Visit(const ir::_LoweredFunc_* expr, Expr* op) override { - auto* node = op->As(); - if (!node->body.As()) { - node->body = ir::Block::Make({node->body}); - } - - ir::IRMutator<>::Visit(expr, op); - } -}; - -void RemoveNestedBlock(Expr* e) { - NestedBlockRemover()(e); - NestedBlockSimplifer()(e); - AddBlockToForloop()(e); -} - -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/remove_nested_block_test.cc b/paddle/cinn/optim/remove_nested_block_test.cc deleted file mode 100644 index 27238329dfbd7..0000000000000 --- a/paddle/cinn/optim/remove_nested_block_test.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/cinn/optim/remove_nested_block.h" - -#include - -#include -#include - -#include "paddle/cinn/ir/utils/ir_printer.h" -#include "paddle/cinn/utils/string.h" - -namespace cinn { -namespace optim { - -TEST(RemoveNestedBlock, basic) { - auto block0 = ir::Block::Make({Expr(1.f), Expr(1.f)}); - auto block1 = ir::Block::Make({block0}); - auto e = Expr(block1); - - std::string origin = utils::GetStreamCnt(e); - EXPECT_EQ(origin, utils::Trim(R"ROC( -{ - { - 1.00000000f - 1.00000000f - } -} - )ROC")); - - std::cout << "origin:\n" << e << std::endl; - - RemoveNestedBlock(&e); - - std::cout << "e:\n" << e << std::endl; - - EXPECT_EQ(utils::GetStreamCnt(e), utils::Trim(R"ROC( -{ - 1.00000000f - 1.00000000f -} - )ROC")); -} - -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/tensor_write_tell.h b/paddle/cinn/optim/tensor_write_tell.h deleted file mode 100644 index f8ee114561a30..0000000000000 --- a/paddle/cinn/optim/tensor_write_tell.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include - -#include "paddle/cinn/ir/ir.h" -#include "paddle/cinn/ir/utils/ir_mutator.h" - -namespace cinn { -namespace optim { - -struct TensorWriteTeller : public ir::IRMutator { - //! Collect the write info in \p op. 
- void Collect(const Expr* op) { Visit(op, op); } - - bool IsWrite(const std::string& tensor_name) const { - return tensor_written.count(tensor_name); - } - - private: - std::set tensor_written; - - void Visit(const Expr* expr, const Expr* op) override { - IRMutator::Visit(expr, op); - } - - void Visit(const ir::Store* expr, const Expr* op) override { - auto* node = op->As(); - CHECK(node); - auto* tensor = node->tensor.As(); - CHECK(tensor); - tensor_written.insert(tensor->name); - IRMutator::Visit(expr, op); - } - - void Visit(const ir::_Tensor_* op, const Expr* expr) override { - auto* node = expr->As(); - if (node->is_call_node()) { - tensor_written.insert(node->name); - } - } -}; - -} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/vectorize_loops.cc b/paddle/cinn/optim/vectorize_loops.cc index 745bec47b4507..2f3a9b29a3567 100644 --- a/paddle/cinn/optim/vectorize_loops.cc +++ b/paddle/cinn/optim/vectorize_loops.cc @@ -31,7 +31,6 @@ #include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/optim/ir_replace.h" #include "paddle/cinn/optim/ir_simplify.h" -#include "paddle/cinn/optim/tensor_write_tell.h" #include "paddle/cinn/optim/unroll_loops.h" #include "paddle/cinn/utils/functional.h" @@ -185,7 +184,7 @@ class CudaVectorizer : public IRMutator { const Var iter_var_; // the loop var of the vectorized loop const int factor_; // the vectorization factor - TensorWriteTeller write_teller_; + std::set write_teller_; TensorVectorizeTeller vectorized_teller_; absl::flat_hash_map tensor2vectorized_vars_; @@ -215,7 +214,7 @@ class CudaVectorizer : public IRMutator { } void Visit(Expr *expr) { - write_teller_.Collect(expr); + write_teller_ = ir::CollectTensorNeedsWrite(expr); vectorized_teller_.Collect(expr); IRMutator::Visit(expr, expr); } @@ -289,7 +288,7 @@ class CudaVectorizer : public IRMutator { const std::vector &indices, bool is_store) { auto *node = tensor.As(); - bool is_const = !write_teller_.IsWrite(node->name); + bool is_const = !write_teller_.count(node->name); // generate the corresponding vector type Type scalar_type = tensor->type().ElementOf(); diff --git a/paddle/cinn/runtime/cuda/float16.h b/paddle/cinn/runtime/cuda/float16.h index cae59186dc832..d64731387d596 100644 --- a/paddle/cinn/runtime/cuda/float16.h +++ b/paddle/cinn/runtime/cuda/float16.h @@ -597,9 +597,9 @@ __host__ __device__ inline bool(isfinite)(const float16& a) { __host__ __device__ inline float16(abs)(const float16& a) { #if defined(CINN_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) - return float16(__habs(a.to_half())); + return static_cast(__habs(a.to_half())); #else - return float16(fabsf(static_cast(a))); + return static_cast(fabsf(static_cast(a))); #endif } diff --git a/paddle/cinn/utils/attribute_util.h b/paddle/cinn/utils/attribute_util.h index aaffed7085c7b..17c1471c38c2d 100644 --- a/paddle/cinn/utils/attribute_util.h +++ b/paddle/cinn/utils/attribute_util.h @@ -18,29 +18,29 @@ #include "paddle/cinn/common/type.h" #include "paddle/cinn/utils/type_defs.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include "paddle/ir/core/builtin_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/phi/common/data_type.h" +#include "paddle/pir/core/builtin_type.h" namespace cinn { namespace utils { -using NewIR_AttributeMap = std::unordered_map; +using NewIR_AttributeMap = std::unordered_map; -Attribute ConvertAttribute(const ::ir::Attribute& src_attr) { +Attribute ConvertAttribute(const ::pir::Attribute&
src_attr) { Attribute dst_attr; - if (src_attr.isa<::ir::BoolAttribute>()) { - dst_attr = src_attr.dyn_cast<::ir::BoolAttribute>().data(); - } else if (src_attr.isa<::ir::FloatAttribute>()) { - dst_attr = src_attr.dyn_cast<::ir::FloatAttribute>().data(); - } else if (src_attr.isa<::ir::Int32Attribute>()) { - dst_attr = src_attr.dyn_cast<::ir::Int32Attribute>().data(); - } else if (src_attr.isa<::ir::StrAttribute>()) { - dst_attr = src_attr.dyn_cast<::ir::StrAttribute>().AsString(); - } else if (src_attr.isa<::ir::Int64Attribute>()) { - dst_attr = src_attr.dyn_cast<::ir::Int64Attribute>().data(); - } else if (src_attr.isa<::ir::DoubleAttribute>()) { - dst_attr = src_attr.dyn_cast<::ir::DoubleAttribute>().data(); + if (src_attr.isa<::pir::BoolAttribute>()) { + dst_attr = src_attr.dyn_cast<::pir::BoolAttribute>().data(); + } else if (src_attr.isa<::pir::FloatAttribute>()) { + dst_attr = src_attr.dyn_cast<::pir::FloatAttribute>().data(); + } else if (src_attr.isa<::pir::Int32Attribute>()) { + dst_attr = src_attr.dyn_cast<::pir::Int32Attribute>().data(); + } else if (src_attr.isa<::pir::StrAttribute>()) { + dst_attr = src_attr.dyn_cast<::pir::StrAttribute>().AsString(); + } else if (src_attr.isa<::pir::Int64Attribute>()) { + dst_attr = src_attr.dyn_cast<::pir::Int64Attribute>().data(); + } else if (src_attr.isa<::pir::DoubleAttribute>()) { + dst_attr = src_attr.dyn_cast<::pir::DoubleAttribute>().data(); } else if (src_attr.isa()) { auto& arr = src_attr.dyn_cast() .data() @@ -75,10 +75,10 @@ AttributeMap ConvertAttributes(const NewIR_AttributeMap& src_attrs) { } #define CASE_TYPE(src, dst) \ - else if (type.isa<::ir::src>()) return common::dst(); + else if (type.isa<::pir::src>()) return common::dst(); -common::Type ConvertIRType(::ir::Type type) { - if (type.isa<::ir::BFloat16Type>()) return common::BF16(); +common::Type ConvertIRType(::pir::Type type) { + if (type.isa<::pir::BFloat16Type>()) return common::BF16(); CASE_TYPE(Float16Type, F16) CASE_TYPE(Float32Type, F32) CASE_TYPE(Float64Type, F64) diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index 628bf6d00c11c..c8e35ad43a36b 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -8,7 +8,7 @@ add_subdirectory(pybind) add_subdirectory(eager) add_subdirectory(prim) add_subdirectory(jit) -add_subdirectory(ir) +add_subdirectory(pir) add_subdirectory(ir_adaptor) add_subdirectory(primitive) # NOTE: please add subdirectory inference at last. 
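For orientation, a minimal sketch of how the converters in paddle/cinn/utils/attribute_util.h above might be exercised after the rename. The IrContext handle and the sample values are assumptions for illustration; only ConvertAttribute and ConvertIRType come from this header, and the pir builtin getters are assumed to follow their usual get(ctx, ...) form:

  // Hypothetical usage sketch (values assumed, not from this patch).
  ::pir::IrContext* ctx = ::pir::IrContext::Instance();
  ::pir::Attribute src = ::pir::Int32Attribute::get(ctx, 42);
  cinn::utils::Attribute dst = cinn::utils::ConvertAttribute(src);  // variant holding int32_t 42
  common::Type ty = cinn::utils::ConvertIRType(::pir::Float32Type::get(ctx));  // common::F32()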
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.cc index 61b2c6bb91c46..81f25a8d6ed88 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.cc @@ -63,7 +63,6 @@ LayerNormSPMDRule::InferForward(const std::vector& input_specs, int begin_norm_axis = ExtractAttr("begin_norm_axis", attrs); - // Step2.3.2 handle input tensor partial (TODO) VLOG(4) << "LayerNormSPMDRule InferForward Inputs: " << "x shape: [" << str_join(x_shape) << "], x_dims_mapping: [" << str_join(x_dims_mapping) << "]; scale shape: [" @@ -74,9 +73,9 @@ LayerNormSPMDRule::InferForward(const std::vector& input_specs, << begin_norm_axis << "]; "; // step1: build Einsum Notation - // ijk,k,k->ijk,x,x (x,scale,bias->out,mean,variance, begin_norm_axis=2, x=ij) - // ijkl,y(kl),y(kl)->ijkl,x(ij),x(ij) (x,scale,bias->out,mean,variance, - // begin_norm_axis=2, x=ij, y=kl) + // ijk,k,k->ijk,z,z (x,scale,bias->out,mean,variance, begin_norm_axis=2, z=ij) + // ijkl,y(kl),y(kl)->ijkl,z(ij),z(ij) (x,scale,bias->out,mean,variance, + // begin_norm_axis=2, z=ij, y=kl) std::string x_axes = ""; for (auto i = 0; i < x_ndim; ++i) { x_axes += static_cast(static_cast('k') - begin_norm_axis + i); @@ -124,15 +123,20 @@ LayerNormSPMDRule::InferForward(const std::vector& input_specs, out_dims_mapping.reserve(out_axes.size()); int64_t mean_shard_dim = -1; - for (size_t i = 0; i < out_axes.size(); ++i) { - if (i < static_cast(begin_norm_axis)) { - out_dims_mapping.push_back(x_dims_mapping[i]); - // if ijk,k,k->ijk,x,x (x,scale,bias->out,mean,variance, - // begin_norm_axis=2, x=ij), and the dims_mapping of input is (0,1,-1), + // As the mean and variance in outputs are `flattened` from + // x[0:begin_norm_axis], only the first axis can be sharded, + // the axes 1 to begin_norm_axis-1 are set to be replicated. + std::vector x_dims_mapping_dst(x_ndim, -1); + x_dims_mapping_dst[0] = x_dims_mapping[0]; + for (int i = 0; i < x_ndim; ++i) { + if (i < begin_norm_axis) { + out_dims_mapping.push_back(x_dims_mapping_dst[i]); + // if ijk,k,k->ijk,z,z (x,scale,bias->out,mean,variance, + // begin_norm_axis=2, z=ij), and the dims_mapping of input is (0,1,-1), // the mean and variance are sharded by dims 0 and 1, // which is not supported currently. - mean_shard_dim = - ShardingMergeForAxis(mean_axes, mean_shard_dim, x_dims_mapping[i]); + mean_shard_dim = ShardingMergeForAxis( + mean_axes, mean_shard_dim, x_dims_mapping_dst[i]); } else { out_dims_mapping.push_back(-1); } @@ -142,7 +146,7 @@ LayerNormSPMDRule::InferForward(const std::vector& input_specs, varience_dist_attr_dst.set_dims_mapping({mean_shard_dim}); // step2.3: Merge and get Inputs' New Dims Mapping. - x_dist_attr_dst.set_dims_mapping(out_dims_mapping); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst); input_dist_attrs.emplace_back(x_dist_attr_dst); // TODO(zhiqiu): support sharding on scale and bias // Now, apply replicating.
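To make the changed forward rule concrete, a worked example with assumed values (not taken from this patch): for x of shape [B, S, H] with begin_norm_axis = 2 and x dims_mapping (0, 1, -1), the destination mapping keeps only the first axis, so x_dims_mapping_dst = (0, -1, -1); out then gets (0, -1, -1), and mean and variance get dims_mapping (0), since both are flattened from x[0:begin_norm_axis] and can only follow the sharding of the first axis.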
@@ -173,12 +177,102 @@ LayerNormSPMDRule::InferForward(const std::vector& input_specs, std::pair, std::vector> LayerNormSPMDRule::InferBackward( + const std::vector& input_specs, const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) { - PADDLE_THROW(phi::errors::Unimplemented( - "InferBackward of LayerNormSPMDRule is NOT implemented yet.")); + // step0: verify input args based on layer_norm logic + int64_t ninputs = input_specs.size(); + int64_t noutputs = output_specs.size(); + PADDLE_ENFORCE_EQ( + ninputs, + 3, + phi::errors::InvalidArgument( + "The size of InputSpec of layer_norm should be 3, but got [%d].", + ninputs)); + PADDLE_ENFORCE_EQ( + noutputs, + 3, + phi::errors::InvalidArgument( + "The size of OutputSpec of layer_norm should be 3, but got [%d].", + noutputs)); + VerifySpecs(output_specs, "layer_norm_backward"); + + // step1: build Einsum Notation + // ijk,k,k->ijk,z,z (x,scale,bias->out,mean,variance, begin_norm_axis=2, z=ij) + // ijkl,y(kl),y(kl)->ijkl,z(ij),z(ij) (x,scale,bias->out,mean,variance, + // begin_norm_axis=2, z=ij, y=kl) + int begin_norm_axis = ExtractAttr("begin_norm_axis", attrs); + std::string alphabet = "ijklmnopqrstuvwxyz"; + int x_ndim = input_specs[0].shape().size(); + std::string x_axes = alphabet.substr(0, x_ndim); + // only the first axis of x may keep its sharding, + // so set the notation of all other axes to '1' (replicated). + for (int i = 1; i < x_ndim; i++) { + x_axes[i] = '1'; + } + std::string out_axes = x_axes; + std::string mean_axes(1, '1'), varience_axes(1, '1'); + if (begin_norm_axis > 0) { + mean_axes[0] = out_axes[0]; + varience_axes[0] = out_axes[0]; + } + std::vector output_axes_vec; + output_axes_vec.emplace_back(out_axes); + output_axes_vec.emplace_back(mean_axes); + output_axes_vec.emplace_back(varience_axes); + + // step2: Sharding Propagation + // For the axes after norm_axis in both input and output tensors, + // set their dims mappings to -1. For the other axes, set input + // tensor's dims mapping the same as output tensor's dims mapping. + // step2.1 merge dims mappings of output, mean, variance.
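// Illustrative walk-through with assumed mappings (not from this patch):
// with begin_norm_axis = 2 and x_ndim = 3, x_axes and out_axes become "i11"
// and mean_axes = varience_axes = "i"; if out, mean and variance carry
// dims_mappings (0, -1, -1), (0) and (0), the merge below binds i -> 0, so
// x is inferred as (0, -1, -1) while scale and bias stay replicated.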
+ std::vector>> axes_sharding_info; + axes_sharding_info = GetAxesDimsMappingPair(output_axes_vec, output_specs); + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors(axes_sharding_info); + + // step2.2 infer input dims mapping + std::vector input_dims_mapping = + GetDimsMappingForAxes(x_axes, axis_to_dim_map); + std::vector input_dist_attrs; + for (int64_t i = 0; i < ninputs; i++) { + input_dist_attrs.emplace_back(input_specs[i].dist_attr()); + } + input_dist_attrs[0].set_dims_mapping(input_dims_mapping); + // set bias and scale to be replicated + input_dist_attrs[1].set_dims_mapping({-1}); + input_dist_attrs[2].set_dims_mapping({-1}); + + // step2.3 update output dims mappings with merged one + std::vector output_dist_attrs; + for (int64_t i = 0; i < noutputs; i++) { + output_dist_attrs.emplace_back(output_specs[i].dist_attr()); + output_dist_attrs[i].set_dims_mapping( + GetDimsMappingForAxes(output_axes_vec[i], axis_to_dim_map)); + } + + VLOG(4) << "LayerNormSPMDRule InferBackward:"; + VLOG(4) << "begin_norm_axis: " << begin_norm_axis; + for (int64_t i = 0; i < noutputs; i++) { + VLOG(4) << "Output" << std::to_string(i) << " shape: [" + << str_join(output_specs[i].shape()) << "] " + << "Einsum Notation: " << output_axes_vec[i] + << " src_dims_mapping: [" + << str_join(output_specs[i].dims_mapping()) << "] " + << "dst_dims_mapping: [" + << str_join(output_dist_attrs[i].dims_mapping()) << "]"; + } + + for (int64_t i = 0; i < ninputs; i++) { + VLOG(4) << "Input" << std::to_string(i) << " shape: [" + << str_join(input_specs[i].shape()) << "] " + << "Einsum Notation: " << std::string(i == 0 ? x_axes : "1") + << " dims_mapping: [" + << str_join(input_dist_attrs[i].dims_mapping()) << "]"; + } + VLOG(4) << std::endl; - return {}; + return {input_dist_attrs, output_dist_attrs}; } } // namespace auto_parallel diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h index b3bd6b6b18faf..da40f3da5653f 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h @@ -32,7 +32,8 @@ class LayerNormSPMDRule : public SPMDRuleBase { const paddle::framework::AttributeMap& attrs) override; std::pair, std::vector> - InferBackward(const std::vector& output_specs, + InferBackward(const std::vector& input_specs, + const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) override; }; } // namespace auto_parallel diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc index 6f50c17fc5c2b..51b4f4b10c675 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc @@ -73,7 +73,7 @@ SplitSPMDRule::InferForward(const std::vector& input_specs, std::unordered_map axis_to_dim_map = ShardingMergeForTensors(axes_sharding_info); - // step2.2: infer output dimsmapping from merged input dimsmapping + // step2.2: infer output dims mapping from merged input dims mapping std::vector output_dims_mapping = GetDimsMappingForAxes(output_axes, axis_to_dim_map); @@ -94,7 +94,7 @@ SplitSPMDRule::InferForward(const std::vector& input_specs, new_input_dims_mapping[axis] = -1; new_input_dist_attrs[0].set_dims_mapping(new_input_dims_mapping); - // Step2.4 handle input tensor partial (TODO) + // Step3 Handle 
 
   VLOG(4) << "SplitSPMDRule InferForward: ";
   for (int64_t i = 0; i < ninputs; i++) {
     VLOG(4) << "Input" << std::to_string(i) << " shape: ["
@@ -113,12 +113,104 @@ SplitSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
 }
 
 std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-SplitSPMDRule::InferBackward(const std::vector<DistTensorSpec>& output_specs,
+SplitSPMDRule::InferBackward(const std::vector<DistTensorSpec>& input_specs,
+                             const std::vector<DistTensorSpec>& output_specs,
                              const paddle::framework::AttributeMap& attrs) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "InferBackward of SplitPMDRule is NOT implemented yet."));
+  // step0: Verify Input Args Based on Elementwise Logic
+  int64_t ninputs = input_specs.size();
+  int64_t noutputs = output_specs.size();
+  PADDLE_ENFORCE_EQ(
+      ninputs,
+      1,
+      phi::errors::InvalidArgument("The size of InputSpec in split must "
+                                   "be equal to 1, but got [%d].",
+                                   ninputs));
+  VerifySpecs(output_specs, "split");
+
+  // check whether the size of output_specs equals
+  // the specified split num in op attributes
+  int64_t specified_split_num = -1;
+  // split api uses num or sections as attribute
+  if (attrs.find("num") != attrs.end()) {
+    specified_split_num = ExtractAttr<int>("num", attrs);
+  } else if (attrs.find("sections") != attrs.end()) {
+    std::vector<int> sections =
+        ExtractAttr<std::vector<int>>("sections", attrs);
+    specified_split_num = sections.size();
+  }
+  PADDLE_ENFORCE_EQ(
+      noutputs,
+      specified_split_num,
+      phi::errors::InvalidArgument("The size of OutputSpec [%d] is not equal "
+                                   "to the specified split number [%d].",
+                                   noutputs,
+                                   specified_split_num));
+
+  // step1: Build Einsum Notation
+  int64_t ndim = input_specs[0].shape().size();
+  int64_t axis = ExtractAttr<int>("axis", attrs);
+  if (axis < 0) {
+    axis += ndim;
+  }
+  // 'k' is reserved as the split-axis marker below,
+  // so it is deliberately omitted from the alphabet.
+  std::string alphabet = "abcdefghijlmnopqrstuvwxyz";
+
+  // get einsum notation for input, use a special
+  // notation 'k' to mark the split axis in input
+  std::string input_axes = alphabet.substr(0, ndim);
+  input_axes[axis] = 'k';
+
+  // get einsum notation for output
+  std::string output_axes(input_axes);
+  output_axes[axis] = 'k';
+
+  // step2: Sharding Propagation
+  // step2.1: merge output shardings
+  std::vector<std::string> output_axes_vec;
+  for (int64_t i = 0; i < noutputs; i++) {
+    output_axes_vec.emplace_back(output_axes);
+  }
+  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
+  axes_sharding_info = GetAxesDimsMappingPair(output_axes_vec, output_specs);
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors(axes_sharding_info);
+
+  // step2.2: infer input dims mapping from output dims mapping;
+  // the split axis in input is set to -1.
+  std::vector<int64_t> input_dims_mapping =
+      GetDimsMappingForAxes(input_axes, axis_to_dim_map, true);
+  input_dims_mapping[axis] = -1;
+  TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
+  input_dist_attr.set_dims_mapping(input_dims_mapping);
+
+  // step2.3 get new dist attributes for the outputs. the split axis
+  // cannot be sharded; if it is sharded, set it to replicated.
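Aside: the ShardingMergeForTensors calls above combine per-axis mesh-dim assignments reported by several tensors. A simplified standalone sketch of that merge, with conflict handling reduced to "replicate on disagreement" (the real routine is more elaborate):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Each pair is (axes string, dims mapping) for one tensor; matching sizes
// are assumed. Agreeing assignments are kept, disagreements fall back to
// replicated (-1).
std::unordered_map<std::string, int64_t> MergeAxes(
    const std::vector<std::pair<std::string, std::vector<int64_t>>>& infos) {
  std::unordered_map<std::string, int64_t> merged;
  for (const auto& [axes, dims_mapping] : infos) {
    for (size_t i = 0; i < axes.size(); ++i) {
      std::string axis(1, axes[i]);
      auto it = merged.find(axis);
      if (it == merged.end()) {
        merged[axis] = dims_mapping[i];
      } else if (it->second != dims_mapping[i]) {
        it->second = -1;  // conflicting shardings -> replicate
      }
    }
  }
  return merged;
}

int main() {
  // Two outputs agree that axis 'a' lives on mesh dim 0, a third wants
  // mesh dim 1, so 'a' ends up replicated; 'k' was already replicated.
  auto merged = MergeAxes({{"ak", {0, -1}}, {"a", {0}}, {"a", {1}}});
  std::cout << merged["a"] << ' ' << merged["k"];  // prints: -1 -1
}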
+ std::vector output_dist_attrs; + for (int64_t i = 0; i < noutputs; i++) { + output_dist_attrs.emplace_back(output_specs[i].dist_attr()); + std::vector out_dims_mapping = + GetDimsMappingForAxes(output_axes, axis_to_dim_map, true); + out_dims_mapping[axis] = -1; + output_dist_attrs[i].set_dims_mapping(out_dims_mapping); + } + + // step3 Handle input tensor partial (TODO) + + VLOG(4) << "SplitSPMDRule InferBackward: "; + for (int64_t i = 0; i < noutputs; i++) { + VLOG(4) << "Output" << std::to_string(i) << " shape: [" + << str_join(output_specs[i].shape()) << "] " + << "einsum_notation: " << output_axes << " dims_mapping: [" + << str_join(output_specs[i].dims_mapping()) << "]"; + } + for (int64_t i = 0; i < ninputs; i++) { + VLOG(4) << "Input" << std::to_string(i) << " shape: [" + << str_join(input_specs[i].shape()) << "] " + << "einsum_notation: " << input_axes << " dims_mapping: [" + << str_join(input_dims_mapping) << "]"; + } + VLOG(4) << std::endl; - return {}; + return {{input_dist_attr}, output_dist_attrs}; } } // namespace auto_parallel diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h index f974e4cccce05..f8a1300e62409 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h @@ -32,7 +32,8 @@ class SplitSPMDRule : public SPMDRuleBase { const paddle::framework::AttributeMap& attrs) override; std::pair, std::vector> - InferBackward(const std::vector& output_specs, + InferBackward(const std::vector& input_specs, + const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) override; }; } // namespace auto_parallel diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 70ab3b94de3c5..6dc25faa80b4b 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -284,6 +284,14 @@ static std::shared_ptr GetGC( max_memory_size)); } } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (platform::is_custom_place(place)) { + if (framework::IsFastEagerDeletionModeEnabled()) { + gc.reset(new framework::CustomDeviceUnsafeFastGarbageCollector( + place, max_memory_size)); + } + } #endif } // max_memory_size >= 0 diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 7c9cf9c8112ef..7645abf24cfd3 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -29,8 +29,8 @@ #include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/value.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/eager/auto_code_generator/CMakeLists.txt b/paddle/fluid/eager/auto_code_generator/CMakeLists.txt index e27310dea5629..25d2f4dacfd16 100644 --- a/paddle/fluid/eager/auto_code_generator/CMakeLists.txt +++ b/paddle/fluid/eager/auto_code_generator/CMakeLists.txt @@ -65,7 +65,7 @@ if(WIN32) add_custom_command( OUTPUT ${eager_generator_path}/ir.dll COMMAND ${CMAKE_COMMAND} -E copy ${IR_LIB} ${eager_generator_path} - DEPENDS ir) + DEPENDS pir) list(APPEND EAGER_CODEGEN_DEPS ${eager_generator_path}/ir.dll) endif() diff --git 
a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index da4a9aab53870..519e50b9175cc 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -81,6 +81,30 @@ "matmul_grad": {"x": "grad_y", "y": "grad_x"}, } +strided_op_list = { + "as_complex", + "as_real", + "as_strided", + "real", + "imag", + "diagonal", + "flatten", + "flatten_infer", + "reshape", + "slice", + "squeeze_infer", + "squeeze", + "strided_slice", + "strided_slice_raw", + "tensor_unfold", + "transpose", + "unbind", + "unsqueeze_infer", + "unsqueeze", + "view_shape", + "view_dtype", +} + ######### # Utils # @@ -234,6 +258,9 @@ class {} : public egr::GradNodeBase {{ // Node Declaration std::shared_ptr<{}> grad_node; + // Pre contiguous tensor in not strided op, if 1)require_any_grad=true; 2) need wrapper to backward; 3) not contiguous +{} + // Set grad_node before API Call {} @@ -380,6 +407,7 @@ class {} : public egr::GradNodeBase {{ #include "paddle/fluid/prim/api/all.h" #include "paddle/fluid/prim/utils/utils.h" #include "paddle/phi/core/flags.h" +#include "paddle/phi/api/lib/data_transform.h" PHI_DECLARE_bool(check_nan_inf); {} """ @@ -408,6 +436,7 @@ class {} : public egr::GradNodeBase {{ #include "paddle/fluid/eager/nan_inf_utils.h" #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" #include "paddle/phi/core/flags.h" +#include "paddle/phi/api/lib/data_transform.h" PHI_DECLARE_bool(check_nan_inf); PHI_DECLARE_string(tensor_operants_mode); @@ -505,6 +534,12 @@ class {} : public egr::GradNodeBase {{ if ({}.initialized()) {{ VLOG(10) << {}.name() << "({}) use_count: " << {}.impl().use_count(); if ({}.impl().use_count() == 1 || ({}.impl().use_count() == 2 && {}.impl().get() == {}.impl().get())) {{ + if ({}.is_dense_tensor() && !std::dynamic_pointer_cast({}.impl())->meta().is_contiguous()) {{ + auto tmp = paddle::experimental::Trans2Contiguous(*(std::dynamic_pointer_cast({}.impl()))); + auto holder = tmp.MoveMemoryHolder(); + std::dynamic_pointer_cast({}.impl())->ResetHolder(holder); + std::dynamic_pointer_cast({}.impl())->set_meta(tmp.meta()); + }} can_be_inplaced = true; }} }}""" @@ -977,6 +1012,7 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): set_attributes_list.append(set_attributes) set_attributes_str = "\n".join(set_attributes_list) + need_pre_contiguous_set = set() # SetTensorWrappers set_input_tensor_wrappers_list = [] set_output_tensor_wrappers_list = [] @@ -1000,12 +1036,30 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): {"indent": indent, "name": name} ) else: - set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});" + if ( + (forward_api_name in strided_op_list) + or for_backward + or IsVectorTensorType(atype) + or (name in self.optional_inputs) + ): + set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});" + else: + need_pre_contiguous_set.add(name) + set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name}_tmp);" else: if is_inplace_input: set_tensor_wrappers = f"{indent}auto {name}_clone = paddle::experimental::assign({name});\n{indent}grad_node->SetTensorWrapper{name}({name}_clone);" else: - set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});" + if ( + (forward_api_name in strided_op_list) + or for_backward + or IsVectorTensorType(atype) + or (name in 
self.optional_inputs) + ): + set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});" + else: + need_pre_contiguous_set.add(name) + set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}_tmp);" set_input_tensor_wrappers_list.append(set_tensor_wrappers) else: # Forwad's output as backward's input if num_fwd_outputs > 1: @@ -1025,6 +1079,24 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): set_output_tensor_wrappers_list ) + if (forward_api_name in strided_op_list) or for_backward: + self.inputs_call_list_tmp = None + self.node_creation_pre_contiguous_str = "" + else: + self.inputs_call_list_tmp = self.inputs_call_list + pre_contiguous_list = [] + for name, (ttype, pos) in forward_inputs_position_map.items(): + if name in need_pre_contiguous_set: + pre_contiguous_list.append( + f"{indent}const auto& {name}_tmp = (require_any_grad && {name}.is_dense_tensor() && !std::dynamic_pointer_cast({name}.impl())->meta().is_contiguous()) ? paddle::Tensor(std::make_shared(std::move(paddle::experimental::Trans2Contiguous(*(std::dynamic_pointer_cast({name}.impl())))))) : {name};" + ) + self.inputs_call_list_tmp[pos] = ( + self.inputs_call_list_tmp[pos] + '_tmp' + ) + self.node_creation_pre_contiguous_str = "\n".join( + pre_contiguous_list + ) + # SetGradOutMeta & SetEdges grad_node_out_list = [] set_grad_out_meta_list = [] @@ -1463,6 +1535,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): inputs_args_declaration_str = ", ".join(inputs_args_declaration_list) inputs_args_definition_str = ", ".join(inputs_args_definition_list) inputs_call_args_str = ", ".join(inputs_call_list) + self.inputs_call_list = inputs_call_list # Forward Full Logic function_name = forward_api_name @@ -1649,6 +1722,12 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): node_creation_str = self.node_creation_str node_creation_before_call_str = self.node_creation_before_call_str node_creation_after_call_str = self.node_creation_after_call_str + node_creation_pre_contiguous_str = ( + self.node_creation_pre_contiguous_str + ) + if self.inputs_call_list_tmp is not None: + inputs_call_args_str_tmp = ", ".join(self.inputs_call_list_tmp) + forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str_tmp});" dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n" forward_ad_function_name = GetDygraphForwardFunctionName( @@ -1760,6 +1839,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): before_log_str, compute_require_grad_args_str, self.grad_node_name, + node_creation_pre_contiguous_str, node_creation_before_call_str, forward_call_str, check_nan_inf_str, @@ -2160,6 +2240,11 @@ def GenerateNodeDefinition( transformed_tensor_name, transformed_tensor_name, tensor_wrapper_intermidiate_tensor_str, + transformed_tensor_name, + transformed_tensor_name, + transformed_tensor_name, + transformed_tensor_name, + transformed_tensor_name, ) inplace_grad_input_str = transformed_tensor_name if is_optional: @@ -2229,6 +2314,11 @@ def GenerateNodeDefinition( transformed_tensor_name, transformed_tensor_name, grads_tensor_str, + transformed_tensor_name, + transformed_tensor_name, + transformed_tensor_name, + transformed_tensor_name, + transformed_tensor_name, ) inplace_grad_input_str = transformed_tensor_name diff --git 
a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index 9532d3181ac73..cdd1f7bfbe945 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/custom_operator_utils.h" #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/api/ext/op_meta_info.h" +#include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/core/dense_tensor.h" namespace egr { @@ -201,7 +202,18 @@ RunCustomOpNode::operator()(paddle::small_vector, } VLOG(6) << "Prepare Grad inputs"; - for (const auto& in : tmp_ins) { + for (auto& in : tmp_ins) { + for (auto& tensor : in) { + if (tensor.initialized() && tensor.is_dense_tensor() && + !std::dynamic_pointer_cast(tensor.impl()) + ->meta() + .is_contiguous()) { + tensor.set_impl(std::make_shared( + std::move(paddle::experimental::Trans2Contiguous(*( + std::dynamic_pointer_cast(tensor.impl())))))); + } + } + ctx.EmplaceBackInputs(in); } VLOG(6) << "Prepare Grad attrs"; diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h index 8f6e6f4028c1d..2a3304fffe63c 100644 --- a/paddle/fluid/eager/to_static/run_program_op_func.h +++ b/paddle/fluid/eager/to_static/run_program_op_func.h @@ -98,6 +98,7 @@ inline void run_program_ad_func( std::vector& dout, // NOLINT const paddle::framework::AttributeMap& attrs) { // Prepare Autograd Meta + VLOG(2) << "start run run_program ad function."; auto deref_out = details::DereferenceTensors(out); std::vector p_autograd_x = egr::EagerUtils::nullable_autograd_meta(x); @@ -174,3 +175,107 @@ inline void run_program_ad_func( egr::EagerUtils::SetHistory(&p_autograd_outs, grad_node); } } + +inline void newir_run_program_ad_func( + const std::vector& x, + const std::vector& params, + std::vector& out, // NOLINT + std::vector& step_scope, // NOLINT + std::vector& dout, // NOLINT + const paddle::framework::AttributeMap& attrs) { + // Prepare Autograd Meta + VLOG(2) << "start run newir run_program ad function."; + auto deref_out = details::DereferenceTensors(out); + std::vector p_autograd_x = + egr::EagerUtils::nullable_autograd_meta(x); + std::vector p_autograd_params = + egr::EagerUtils::nullable_autograd_meta(params); + std::vector p_autograd_outs = + egr::EagerUtils::nullable_autograd_meta(deref_out); + + bool trace_backward = egr::Controller::Instance().HasGrad(); + bool require_any_grad = egr::EagerUtils::ComputeRequireGrad( + trace_backward, &p_autograd_x, &p_autograd_params); + + // Create Middle Output for GradNode. 
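Aside: extraction stripped the angle-bracket template arguments from the generator's f-strings above. Restored by hand, the pre-contiguous guard emitted for an input named x reads roughly as follows; treat this as a sketch of the generated code, not a verbatim quote of the patch:

// Sketch of the generated guard (template arguments restored by hand):
const auto& x_tmp =
    (require_any_grad && x.is_dense_tensor() &&
     !std::dynamic_pointer_cast<phi::DenseTensor>(x.impl())
          ->meta()
          .is_contiguous())
        ? paddle::Tensor(std::make_shared<phi::DenseTensor>(
              std::move(paddle::experimental::Trans2Contiguous(
                  *(std::dynamic_pointer_cast<phi::DenseTensor>(x.impl()))))))
        : x;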
+ auto middle_size = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm")).size(); + auto output_size = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fo")).size(); + auto middles = std::vector(); + std::shared_ptr grad_node; + VLOG(2) << "start run run_program with require_any_grad = " + << require_any_grad; + + if (require_any_grad) { + // Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad]) + grad_node = std::make_shared(1, 2); + grad_node->GetMiddle().resize(middle_size); + grad_node->GetOutputs().resize(output_size); + for (size_t i = 0; i < middle_size; ++i) { + grad_node->GetMiddle()[i] = + paddle::Tensor(std::make_shared()); + middles.push_back(&grad_node->GetMiddle()[i]); + } + for (size_t i = 0; i < output_size; ++i) { + grad_node->GetOutputs()[i] = *out[i]; + } + } + + // Call forward function + // if require_any_grad is False, don't save any middle vars. + NewIRRunProgramAPI( + x, params, out, middles, step_scope, dout, require_any_grad, attrs); + if (require_any_grad) { + // auto x_names = + // PADDLE_GET_CONST(std::vector, attrs.at("x_names")); + + egr::EagerUtils::PassStopGradient(false, &p_autograd_outs); + + // Set Attributes + grad_node->SetAttrMap(attrs); + + // auto* forward_global_block = PADDLE_GET_CONST( + // paddle::framework::BlockDesc*, attrs.at("forward_global_block")); + // auto* backward_global_block = PADDLE_GET_CONST( + // paddle::framework::BlockDesc*, attrs.at("backward_global_block")); + // Clear unused x vars + // auto filter_x = + // filter_unused_input_var_in_backward(x, x_names, backward_global_block); + // Set TensorWrappers + grad_node->SetFwdX(x); + // Clear unused out vars + // clear_unused_out_var_in_backward(out, backward_global_block, + // step_scope[0]); + + grad_node->SetFwdParams(params); + grad_node->SetStepScope(step_scope); // just for set useable. + + // Set Grad out rank as same as fwd input and set stop gradient to bwd + // NOTE(@xiongkun): Not every tensor in x(list of tensor) is required + // gradient. for example: x[1] is not used for output, the x[1] is ignored. + + // TODO(@xiongkun): rewrite by new ir representation. 
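Aside: note the order of operations above: GetMiddle() is resized once before any element addresses are taken. A standalone illustration of why that matters, with int standing in for paddle::Tensor:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<int> storage;
  storage.resize(3);  // fix element addresses up front
  std::vector<int*> views;
  for (std::size_t i = 0; i < storage.size(); ++i) {
    storage[i] = static_cast<int>(i);
    // Safe only because storage will not reallocate anymore; pushing into
    // storage itself inside this loop would invalidate earlier pointers.
    views.push_back(&storage[i]);
  }
  for (int* p : views) std::cout << *p << ' ';  // prints: 0 1 2
}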
+ std::vector x_require_grad; + for (size_t i = 0; i < x.size(); ++i) { + x_require_grad.push_back(&x[i]); + } + + grad_node->SetGradOutMeta(x_require_grad, /*slot id*/ 0); + grad_node->SetGradOutMeta(params, /*slot id*/ 1); + + // VLOG(2) << "clear_no_grad_edges."; + // clear_no_grad_edges_with_partial_block(params, + // forward_global_block, + // backward_global_block, + // grad_node.get(), + // [>slot id<] 1); + + grad_node->SetGradInMeta(deref_out, 0); + + egr::EagerUtils::SetOutRankWithSlot(&p_autograd_outs, 0); + + // Set History for output set current Grad Node for + egr::EagerUtils::SetHistory(&p_autograd_outs, grad_node); + } +} diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 72c61c1723a3b..8f6f8cbbc22fc 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -19,13 +19,16 @@ #include "paddle/fluid/eager/tensor_wrapper.h" #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/framework/variable_helper.h" -#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h" #include "paddle/fluid/operators/run_program_op.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/profiler/event_tracing.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/value.h" PHI_DECLARE_bool(enable_new_ir_in_executor); @@ -175,6 +178,33 @@ static void ShareTensorsIntoScopeWithName( } } +static auto GetNameFromValue(const ::pir::Block *block, + const std::vector<::pir::Value> &values) { + // we use name here, later value is used directly. + std::unordered_map<::pir::Value, std::string> value2name; + for (auto *op : *block) { + std::string name; + if (op->name() == "pd_op.data") { + name = + op->attributes().at("name").dyn_cast().AsString(); + value2name[op->results()[0].Value::impl()] = name; + } else if (op->name() == "builtin.set_parameter") { + name = op->attributes() + .at("parameter_name") + .dyn_cast() + .AsString(); + value2name[op->operand(0).source()] = name; + } + } + std::vector names; + std::transform( + values.begin(), + values.end(), + std::back_inserter(names), + [&value2name](const ::pir::Value &v) { return value2name[v]; }); + return names; +} + static void ShareTensorsFromScope( const std::vector &tensors, const paddle::framework::BlockDesc &global_block, @@ -216,6 +246,52 @@ static void ShareTensorsFromScope( } } +static void ShareTensorsIntoScopeByValue( + const ::pir::Block *block, + const std::vector &tensors, + const std::vector<::pir::Value> &values, + paddle::framework::Scope *scope) { + auto names = GetNameFromValue(block, values); + ShareTensorsIntoScopeWithName(tensors, names, scope); +} + +static void ShareTensorsFromScopeByValue( + const ::pir::Block *block, + const std::vector &tensors, + const std::vector<::pir::Value> &values, + paddle::framework::Scope *scope) { + auto names = GetNameFromValue(block, values); + for (size_t i = 0; i < tensors.size(); ++i) { + auto &name = names[i]; + auto &value = values[i]; + if (value.impl() == nullptr) { + // skip stop_gradient. 
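Aside: GetNameFromValue above follows a common two-phase shape: build a value-to-name map in one pass over the block, then project an arbitrary value list onto names. A standalone sketch with int standing in for ::pir::Value:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Phase 1: one pass over the "block" fills the lookup table.
  std::unordered_map<int, std::string> value2name{
      {101, "x"}, {102, "scale"}, {103, "bias"}};
  // Phase 2: project any value list onto names, preserving order.
  std::vector<int> values{103, 101};
  std::vector<std::string> names;
  std::transform(values.begin(), values.end(), std::back_inserter(names),
                 [&value2name](int v) { return value2name[v]; });
  for (const auto& n : names) std::cout << n << ' ';  // prints: bias x
}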
+ continue; + } + auto *var = scope->FindVar(name); + PADDLE_ENFORCE_NOT_NULL( + var, + paddle::platform::errors::NotFound("The output tensor %s is not in " + "RunProgram(Grad)Op'" + "s internal scope.", + name)); + CheckOutputVarStatus(*var, *tensors[i]); + // share tensor + if (var->IsType()) { + auto &src_tensor = var->Get(); + auto *dst_tensor = const_cast( + dynamic_cast(tensors[i]->impl().get())); + VLOG(2) << "share " << name << " from scope"; + *dst_tensor = src_tensor; + } else if (var->IsType()) { + auto &src_tensor = var->Get(); + auto *dst_tensor = const_cast( + dynamic_cast(tensors[i]->impl().get())); + *dst_tensor = src_tensor; + } + } +} + static void ShareTensorsFromScopeWithPartialBlock( const std::vector &tensors, const paddle::framework::BlockDesc &forward_global_block, @@ -309,8 +385,194 @@ static void GcScope(paddle::framework::Scope *scope) { delete garbages; // free mem } +template +void print_collection(const T &t) { + VLOG(5) << "Print collection start :"; + for (auto s : t) { + VLOG(5) << s; + } + VLOG(5) << "Print collection end."; +} + } // namespace details +inline void NewIRRunProgramAPI( + const std::vector &x, + const std::vector ¶ms, + std::vector &out, // NOLINT + std::vector &middles, // NOLINT + std::vector &step_scope, // NOLINT + std::vector &dout, // NOLINT + bool require_any_grad, + const paddle::framework::AttributeMap &attrs) { + VLOG(2) << "RunProgramOpKernel Compute"; + // In the original run_program OP, the default value of the is_test + // attribute is false, we should check if there is is_test parameter + // in attrs + auto is_test = false; + if (attrs.count("is_test")) { + is_test = PADDLE_GET_CONST(bool, attrs.at("is_test")); + } + int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id")); + auto place = egr::Controller::Instance().GetExpectedPlace(); + + // NOTE(chenweihang): In order not to add new variable type, use vector + // here. Originally, here can use scope directly. 
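Aside: the interpreter cores created below are cached per (program_id, scope, is_grad), so forward and backward of the same program each get a dedicated core while repeated calls reuse it. A simplified stand-in for InterpreterCoreInfoCache; Core and Scope here are placeholders, not Paddle types:

#include <cstdint>
#include <map>
#include <memory>
#include <tuple>

struct Core {};
struct Scope {};

class CoreCache {
 public:
  // Key mirrors the real cache lookups: program id, scope, is_grad.
  using Key = std::tuple<int64_t, const Scope*, bool>;
  bool Has(const Key& k) const { return cores_.count(k) != 0; }
  std::shared_ptr<Core>& GetMutable(const Key& k) { return cores_[k]; }

 private:
  std::map<Key, std::shared_ptr<Core>> cores_;
};

int main() {
  CoreCache cache;
  Scope scope;
  CoreCache::Key fwd{42, &scope, false};
  if (!cache.Has(fwd)) cache.GetMutable(fwd) = std::make_shared<Core>();
  // A later backward run keyed {42, &scope, true} misses this entry and
  // builds its own core.
}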
+ auto *out_scope_vec = &step_scope; + PADDLE_ENFORCE_EQ( + out_scope_vec->size(), + 1, + paddle::platform::errors::InvalidArgument( + "The OutScope of RunProgramGradOp should only hold one scope.")); + + VLOG(2) << "RunProgramOp use interpretercore to execute program."; + + paddle::framework::Scope *global_inner_scope = out_scope_vec->front(); + + VLOG(4) << "global_inner_scope:" << global_inner_scope; + + auto input_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fx")); + auto output_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fo")); + auto middle_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm")); + auto param_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp")); + // auto dout_names = + // PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fp")); + + auto *forward_global_block = + PADDLE_GET_CONST(::pir::Block *, attrs.at("forward_global_block")); + auto *backward_global_block = + PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block")); + + auto *forward_program = + forward_global_block->GetParentOp()->GetParentProgram(); + auto *backward_program = + backward_global_block->GetParentOp()->GetParentProgram(); + + if (VLOG_IS_ON(4)) { + std::ostringstream print_stream; + forward_program->Print(print_stream); + print_stream << "\n"; + backward_program->Print(print_stream); + VLOG(4) << print_stream.str(); + } + + VLOG(10) << is_test << program_id; + + auto &interpretercore_info_cache = + paddle::framework::InterpreterCoreInfoCache::Instance(); + std::shared_ptr interpreter_core = + nullptr; + if (!interpretercore_info_cache.Has( + program_id, global_inner_scope, /*is_grad=*/false)) { + paddle::platform::RecordEvent record_event( + "create_new_interpretercore", + paddle::platform::TracerEventType::UserDefined, + 1); + VLOG(2) << "No interpretercore cache, so create a new interpretercore " + "for program: " + << program_id; + // Step 1. share input_vars & parameters into scope + details::ShareTensorsIntoScopeByValue( + forward_global_block, x, input_values, global_inner_scope); + details::ShareTensorsIntoScopeByValue( + forward_global_block, params, param_values, global_inner_scope); + // Step 2. create new interpretercore + auto kernel_forward_program = + paddle::dialect::PdOpLowerToKernelPass(forward_program, place); + interpreter_core = paddle::framework::CreateNewIRInterpreterCoreInfoToCache( + std::move(kernel_forward_program), + place, + /*is_grad=*/false, + program_id, + global_inner_scope); + // Step 3. 
get all eager gc vars + // std::set skip_eager_delete_vars = + // paddle::framework::details::ParseSafeEagerDeletionSkipVarsSet( + // *backward_program); + + // update interpretercore skip_gc_var + auto skip_names = + details::GetNameFromValue(forward_global_block, middle_values); + auto skip_names_set = + std::set(skip_names.begin(), skip_names.end()); + skip_names = details::GetNameFromValue(forward_global_block, output_values); + skip_names_set.insert(skip_names.begin(), skip_names.end()); + details::print_collection(skip_names_set); + interpreter_core->SetSkipGcVars(skip_names_set); + + // std::set input_vars; + // input_vars.insert(input_names.begin(), input_names.end()); + // interpreter_core->SetJitInputVars(input_vars); + + // interpretercore_info_cache.UpdateSkipEagerDeleteVars( + // program_id, global_inner_scope, false, skip_eager_delete_vars); + } else { + paddle::platform::RecordEvent record_event( + "get_interpretercore_cahce", + paddle::platform::TracerEventType::UserDefined, + 1); + VLOG(2) << "Get interpretercore cache by program:" << program_id; + // Step 1. get cache interpretercore + auto &cached_value = interpretercore_info_cache.GetMutable( + program_id, global_inner_scope, /*is_grad=*/false); + interpreter_core = cached_value.core_; + // Step 2. update scope for cache interpretercore + details::ShareTensorsIntoScopeByValue( + forward_global_block, x, input_values, global_inner_scope); + details::ShareTensorsIntoScopeByValue( + forward_global_block, params, param_values, global_inner_scope); + // TODO(xiongkun): new ir how to build scope. + // if (interpreter_core->GetVariableScope()->GetMutableScope() != + // global_inner_scope) { + // details::BuildScopeByBlock( + // *interpreter_core.get(), *forward_global_block, global_inner_scope); + // interpreter_core->reset_scope(global_inner_scope); + //} + } + + // interpretercore run + if (!forward_global_block->empty()) { + paddle::platform::RecordEvent record_event( + "interpreter_core_run", + paddle::platform::TracerEventType::UserDefined, + 1); + interpreter_core->Run({}); + } + + { + paddle::platform::RecordEvent record_event( + "fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1); + // Get Output, and Middle Outputs + details::ShareTensorsFromScopeByValue( + forward_global_block, out, output_values, global_inner_scope); + details::ShareTensorsFromScopeByValue( + forward_global_block, middles, middle_values, global_inner_scope); + + VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); + + if (is_test || !require_any_grad) { + VLOG(4) << "don't require any grad, set this scope can reused"; + VLOG(4) << "is_test: " << is_test + << ", require_any_grad: " << require_any_grad; + global_inner_scope->SetCanReused(true); + details::GcScope(global_inner_scope); + } else { + VLOG(4) << "not test, set this scope can not reused"; + global_inner_scope->SetCanReused(false); + details::GcScope(global_inner_scope); // we can gc all the time, because + // we save the middles. 
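Aside: the skip-GC set above is just the union of the names behind the middle values and the output values, so the interpreter's garbage collector keeps alive exactly what the backward program will read later. Standalone illustration:

#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> middle_names{"tmp_0", "tmp_1"};
  std::vector<std::string> output_names{"out", "tmp_1"};  // overlap is fine
  std::set<std::string> skip(middle_names.begin(), middle_names.end());
  skip.insert(output_names.begin(), output_names.end());
  for (const auto& n : skip) std::cout << n << ' ';  // prints: out tmp_0 tmp_1
}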
+ } + } + +#ifdef PADDLE_WITH_DNNL + if (FLAGS_use_mkldnn) paddle::platform::DontClearMKLDNNCache(place); +#endif +} + inline void RunProgramAPI( const std::vector &x, const std::vector ¶ms, @@ -403,8 +665,13 @@ inline void RunProgramAPI( if (FLAGS_enable_new_ir_in_executor) { // build new ir program - auto ir_program = paddle::framework::ConstructFowardIrProgram( - forward_global_block, backward_global_block, output_names, x, params); + auto ir_program = + paddle::framework::ConstructFowardIrProgram(forward_global_block, + backward_global_block, + output_names, + x, + params, + place); interpreter_core = paddle::framework::CreateNewIRInterpreterCoreInfoToCache( std::move(ir_program), @@ -660,12 +927,164 @@ inline void RunProgramGradAPI( } } +inline void NewIRRunProgramGradAPI( + const std::vector &x, + const std::vector ¶ms, + const std::vector &out_grad, + const std::vector &middles, + const std::vector &out, + const std::vector &step_scope, // NOLINT + const paddle::framework::AttributeMap &attrs, + std::vector &x_grad, // NOLINT + std::vector ¶ms_grad // NOLINT +) { + // if all output vars are set to stop_gradient, grad op no need to executed + if (x_grad.empty() && params_grad.empty()) return; + auto *out_scope_vec = &step_scope; + PADDLE_ENFORCE_EQ( + out_scope_vec->size(), + 1, + paddle::platform::errors::InvalidArgument( + "The OutScope of RunProgramGradOp should only hold one scope.")); + paddle::framework::Scope *global_inner_scope = out_scope_vec->front(); + + int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id")); + + auto place = egr::Controller::Instance().GetExpectedPlace(); + VLOG(2) << "RunProgramGradOp use interpretercore to execute program."; + + VLOG(4) << "global_inner_scope:" << global_inner_scope; + + auto *backward_global_block = + PADDLE_GET_CONST(::pir::Block *, attrs.at("backward_global_block")); + auto *backward_program = + backward_global_block->GetParentOp()->GetParentProgram(); + + auto output_grad_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo_g")); + auto forward_input_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bx")); + auto forward_middle_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bm")); + auto forward_output_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bo")); + auto x_grad_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bx_g")); + auto p_grad_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("bp_g")); + + auto &interpretercore_info_cache = + paddle::framework::InterpreterCoreInfoCache::Instance(); + std::shared_ptr interpreter_core = + nullptr; + if (!interpretercore_info_cache.Has( + program_id, global_inner_scope, /*is_grad=*/true)) { + paddle::platform::RecordEvent record_event( + "create_new_interpretercore", + paddle::platform::TracerEventType::UserDefined, + 1); + VLOG(2) << "No interpretercore cahce, so create a new interpretercore"; + // Step 1. 
share input_vars & parameters into scope + // x, param, middles, output_grads + details::ShareTensorsIntoScopeByValue(backward_global_block, + out_grad, + output_grad_values, + global_inner_scope); + details::ShareTensorsIntoScopeByValue( + backward_global_block, x, forward_input_values, global_inner_scope); + details::ShareTensorsIntoScopeByValue(backward_global_block, + middles, + forward_middle_values, + global_inner_scope); + details::ShareTensorsIntoScopeByValue( + backward_global_block, out, forward_output_values, global_inner_scope); + auto kernel_backward_program = + paddle::dialect::PdOpLowerToKernelPass(backward_program, place); + interpreter_core = paddle::framework::CreateNewIRInterpreterCoreInfoToCache( + std::move(kernel_backward_program), + place, + /*is_grad=*/true, + program_id, + global_inner_scope); + // share threadpool + // NOTE(zhiqiu): this only works interpreter_core is executed strictly + // after the related fwd_interpreter_core. + if (interpretercore_info_cache.Has(program_id, global_inner_scope, false)) { + auto fwd_interpreter_core = + interpretercore_info_cache + .GetMutable(program_id, global_inner_scope, /*is_grad=*/false) + .core_; + interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core); + VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to " + << interpreter_core.get(); + } + + // get all eager gc vars + std::set skip_eager_delete_vars; + auto skip_names = + details::GetNameFromValue(backward_global_block, x_grad_values); + skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end()); + skip_names = + details::GetNameFromValue(backward_global_block, p_grad_values); + skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end()); + interpreter_core->SetSkipGcVars(skip_eager_delete_vars); + interpretercore_info_cache.UpdateSkipEagerDeleteVars( + program_id, + global_inner_scope, + /*is_grad=*/true, + skip_eager_delete_vars); + VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); + details::print_collection(skip_eager_delete_vars); + } else { + paddle::platform::RecordEvent record_event( + "get_interpretercore_cahce", + paddle::platform::TracerEventType::UserDefined, + 1); + VLOG(2) << "Get interpretercore cahce by program:" << program_id; + auto &cached_value = interpretercore_info_cache.GetMutable( + program_id, global_inner_scope, /*is_grad=*/true); + interpreter_core = cached_value.core_; + + // update scope (TODO: why share again) + // details::ShareTensorsIntoScope(out_grad, global_inner_scope); + // if (interpreter_core->GetVariableScope()->GetMutableScope() != + // global_inner_scope) { + // details::BuildScopeByBlock( + // *interpreter_core.get(), *backward_global_block, global_inner_scope); + // interpreter_core->reset_scope(global_inner_scope); + //} + } + + if (!backward_global_block->empty()) { + paddle::platform::RecordEvent record_event( + "interpreter_core_run", + paddle::platform::TracerEventType::UserDefined, + 1); + // Debug info: scope info when run end + VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); + interpreter_core->Run({}); + } + + { + paddle::platform::RecordEvent record_event( + "fetch_and_gc", paddle::platform::TracerEventType::UserDefined, 1); + // Step 4. 
get outputs + details::ShareTensorsFromScopeByValue( + backward_global_block, x_grad, x_grad_values, global_inner_scope); + details::ShareTensorsFromScopeByValue( + backward_global_block, params_grad, p_grad_values, global_inner_scope); + VLOG(4) << "after backward gc all vars"; + global_inner_scope->SetCanReused(true); + details::GcScope(global_inner_scope); + } +} + class GradNodeRunProgram : public egr::GradNodeBase { public: GradNodeRunProgram(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {} - ~GradNodeRunProgram() { + ~GradNodeRunProgram() override { if (!executed_) { auto *out_scope_vec = &step_scope_; VLOG(4) << "~GradNodeRunProgram"; @@ -828,3 +1247,187 @@ class GradNodeRunProgram : public egr::GradNodeBase { bool executed_{false}; }; + +class NewIRGradNodeRunProgram : public egr::GradNodeBase { + public: + NewIRGradNodeRunProgram(size_t bwd_in_slot_num, size_t bwd_out_slot_num) + : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {} + + ~NewIRGradNodeRunProgram() override { + if (!executed_) { + auto *out_scope_vec = &step_scope_; + VLOG(4) << "~GradNodeRunProgram"; + // Normally out_scope_vec.size() == 1. for safty, we add for-loop here. + for (size_t i = 0; i < out_scope_vec->size(); ++i) { + paddle::framework::Scope *global_inner_scope = out_scope_vec->at(i); + global_inner_scope->SetCanReused(true); + details::GcScope(global_inner_scope); + VLOG(4) << "global_inner_scope SetCanReused"; + } + middles_.clear(); + outputs_.clear(); + } + } + // Functor: perform backward computations + virtual paddle::small_vector, + egr::kSlotSmallVectorSize> + operator()(paddle::small_vector, + egr::kSlotSmallVectorSize> &grads, // NOLINT + bool create_graph UNUSED, + bool is_new_grad UNUSED) override { + VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram"; + paddle::small_vector, egr::kSlotSmallVectorSize> + hooked_grads = NewIRGradNodeRunProgram::ApplyGradientHooks(grads); + PADDLE_ENFORCE_EQ(hooked_grads.size(), + 1, + paddle::platform::errors::InvalidArgument( + "The hooked_grads.size() of RunProgramGradOp should " + "be equal to 1.")); + + std::vector x_grad; + std::vector params_grad; + std::vector x_grad_ptr; + std::vector params_grad_ptr; + { + paddle::platform::RecordEvent record_event( + "construct_grad_tensor", + paddle::platform::TracerEventType::UserDefined, + 1); + + egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&hooked_grads[0], + this->InputMeta()[0]); + VLOG(3) << "hooked_grads[0].size() : " << hooked_grads[0].size(); + ConstructXGradTensors(x_, &x_grad); + ConstructParamGradTensors(params_, ¶ms_grad); + for (auto &i : x_grad) { + x_grad_ptr.emplace_back(&i); + } + for (auto &i : params_grad) { + if (i.defined()) { + params_grad_ptr.emplace_back(&i); + } + } + } + + auto out_grad_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs_.at("bo_g")); + PADDLE_ENFORCE_EQ(hooked_grads[0].size(), + out_grad_values.size(), + paddle::platform::errors::InvalidArgument( + "The hooked_grads[0].size() and " + "out_grad_values.size() should be equal.")); + + VLOG(1) << "Run Program Grad API start."; + NewIRRunProgramGradAPI(x_, + params_, + hooked_grads[0], + middles_, + outputs_, + step_scope_, + attrs_, + x_grad_ptr, + params_grad_ptr); + VLOG(1) << "Run Program Grad API end."; + VLOG(3) << "End Eager Backward Node: GradNodeRunProgram"; + + executed_ = true; + return {x_grad, params_grad}; + } + + void ClearTensorWrappers() override { + x_.clear(); + params_.clear(); + middles_.clear(); + outputs_.clear(); + 
SetIsTensorWrappersCleared(true); + } + + // SetAttrMap + void SetAttrMap(const paddle::framework::AttributeMap &attrs) { + attrs_ = attrs; + } + + void SetFwdX(const std::vector &tensors) { x_ = tensors; } + + std::vector &GetMiddle() { return middles_; } + + std::vector &GetOutputs() { return outputs_; } + + void SetFwdParams(const std::vector &tensors) { + params_ = tensors; + } + + void SetStepScope(const std::vector &scopes) { + step_scope_ = scopes; + } + + protected: + void ConstructXGradTensors(const std::vector &x, + std::vector *x_grad) { + auto x_grad_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs_.at("bx_g")); + PADDLE_ENFORCE_EQ( + x.size(), + x_grad_values.size(), + paddle::platform::errors::InvalidArgument( + "The x.size() and x_grad_names.size() should be equal. " + "But received x.size() = %d, x_grad_names.size() = %d", + x.size(), + x_grad_values.size())); + + // TODO(dev): Need an elegant way to determine inforamtion of grad_tensor, + // such as: name, tensor type(DenseTensor or SelectedRows). + for (size_t i = 0; i < x.size(); i++) { + if (x[i].is_dense_tensor()) { + x_grad->emplace_back(std::make_shared()); + } else if (x[i].is_selected_rows()) { + x_grad->emplace_back(std::make_shared()); + } + } + } + + void ConstructParamGradTensors(const std::vector ¶ms, + std::vector *param_grads) { + auto p_grad_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs_.at("bp_g")); + PADDLE_ENFORCE_EQ(params.size(), + p_grad_values.size(), + paddle::platform::errors::InvalidArgument( + "The param.size() and " + "param_grad_names.size() should be equal.")); + + for (size_t i = 0; i < params.size(); ++i) { + auto &p = params[i]; + auto &p_grad = egr::EagerUtils::unsafe_autograd_meta(p)->Grad(); + // In eager mode, the number of param_grad should be the same as + // param, so here an empty Tensor is added for the param with + // stop_gradient=True + if (!p_grad.defined()) { + param_grads->emplace_back(); + } else if (p_grad.is_dense_tensor()) { + param_grads->emplace_back(std::make_shared()); + } else if (p_grad.is_selected_rows()) { + param_grads->emplace_back(std::make_shared()); + } + } + } + + std::shared_ptr Copy() const override { + auto copied_node = std::shared_ptr( + new NewIRGradNodeRunProgram(*this)); + return copied_node; + } + + private: + // TensorWrappers + std::vector x_; + std::vector params_; + std::vector middles_; + std::vector outputs_; + std::vector step_scope_; + + // Attribute Map + paddle::framework::AttributeMap attrs_; + + bool executed_{false}; +}; diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 3befea7d0fd2b..f72d4ad182ddd 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -696,6 +696,7 @@ cc_library( DEPS while_op_helper recurrent_op_helper conditional_block_op_helper + pylayer_op_helper scope proto_desc operator @@ -1014,8 +1015,9 @@ else() monitor) endif() -target_link_libraries(executor while_op_helper executor_gc_helper - recurrent_op_helper conditional_block_op_helper) +target_link_libraries( + executor while_op_helper executor_gc_helper recurrent_op_helper + conditional_block_op_helper pylayer_op_helper) cc_library( parallel_executor @@ -1035,7 +1037,7 @@ cc_library( executor_cache SRCS executor_cache.cc DEPS parallel_executor standalone_executor phi_kernel_adaptor pd_inplace_pass - pd_op_to_kernel_pass ir) + pd_op_to_kernel_pass pir) if(WITH_PSCORE) get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) if(WITH_HETERPS) diff --git 
a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index c2dd1bf37dd19..8814935e3fceb 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -947,12 +947,10 @@ static void RegisterOperatorKernel( #ifdef PADDLE_WITH_CUSTOM_DEVICE auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes(); for (const auto& dev_type : device_types) { - for (auto& dev_id : phi::DeviceManager::GetSelectedDeviceList(dev_type)) { - RegisterOperatorKernelWithPlace(name, - op_kernel_func, - proto::VarType::RAW, - platform::CustomPlace(dev_type, dev_id)); - } + RegisterOperatorKernelWithPlace(name, + op_kernel_func, + proto::VarType::RAW, + platform::CustomPlace(dev_type)); } #endif } diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index f5c4c745cfd51..c03c8542b49c6 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -16,14 +16,14 @@ #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/framework/op_info.h" -#include "paddle/fluid/ir/transforms/inplace_pass.h" -#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/value.h" -#include "paddle/ir/pass/pass.h" -#include "paddle/ir/pass/pass_manager.h" +#include "paddle/fluid/pir/transforms/inplace_pass.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include "paddle/phi/core/flags.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/value.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" PHI_DECLARE_bool(new_ir_apply_inplace_pass); @@ -304,7 +304,7 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( framework::Scope *scope) { auto &interpretercore_info_cache = framework::InterpreterCoreInfoCache::Instance(); - if (interpretercore_info_cache.Size() > 10u /* max_cached_size*/) { + if (interpretercore_info_cache.Size() > 256u /* max_cached_size*/) { VLOG(2) << "The cached info size has exceeded max_cached_size: 4, clear " "all cache!"; interpretercore_info_cache.Finalize(); @@ -325,14 +325,14 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( } std::shared_ptr CreateNewIRInterpreterCoreInfoToCache( - std::unique_ptr<::ir::Program> ir_program, + std::unique_ptr<::pir::Program> ir_program, const platform::Place &place, bool is_grad, int64_t program_id, framework::Scope *scope) { auto &interpretercore_info_cache = framework::InterpreterCoreInfoCache::Instance(); - if (interpretercore_info_cache.Size() > 10u /* max_cached_size*/) { + if (interpretercore_info_cache.Size() > 256u /* max_cached_size*/) { VLOG(2) << "The cached info size has exceeded max_cached_size: 4, clear " "all cache!"; interpretercore_info_cache.Finalize(); @@ -352,14 +352,15 @@ std::shared_ptr CreateNewIRInterpreterCoreInfoToCache( return core; } -std::unique_ptr<::ir::Program> ConstructFowardIrProgram( +std::unique_ptr<::pir::Program> ConstructFowardIrProgram( const paddle::framework::BlockDesc *forward_global_block, const paddle::framework::BlockDesc *backward_global_block, const std::vector output_names, const std::vector &x, - const std::vector ¶ms) { - auto ir_ctx = ::ir::IrContext::Instance(); - auto program = std::make_unique<::ir::Program>(ir_ctx); + const std::vector ¶ms, + const phi::Place &place) { + auto ir_ctx = ::pir::IrContext::Instance(); + auto program = 
std::make_unique<::pir::Program>(ir_ctx); std::set set_output_names; auto local_program = @@ -381,14 +382,14 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram( if (block->FindVarRecursive(name) == nullptr) { continue; } - auto place = in_t.place().GetType(); + auto p = in_t.place().GetType(); auto op_desc = block->PrependOp(); op_desc->SetType("data"); op_desc->SetAttr("shape", std::vector()); // TODO(phlrain) : using tensor dtype op_desc->SetAttr("dtype", 0); - op_desc->SetAttr("place", static_cast(place)); + op_desc->SetAttr("place", static_cast(p)); op_desc->SetAttr("name", name); op_desc->SetOutput("out", {name}); } @@ -396,14 +397,14 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram( std::set input_param_names; for (auto ¶m : params) { auto &name = param.name(); - auto place = param.place().GetType(); + auto p = param.place().GetType(); auto op_desc = local_program.MutableBlock(0)->PrependOp(); op_desc->SetType("data"); op_desc->SetAttr("shape", std::vector()); // TODO(phlrain) : using tensor dtype op_desc->SetAttr("dtype", 0); - op_desc->SetAttr("place", static_cast(place)); + op_desc->SetAttr("place", static_cast(p)); op_desc->SetAttr("name", name); op_desc->SetOutput("out", {name}); @@ -445,25 +446,25 @@ std::unique_ptr<::ir::Program> ConstructFowardIrProgram( program_translator.Translate(); - auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program.get()); + auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program.get(), place); if (FLAGS_new_ir_apply_inplace_pass) { - ::ir::PassManager pm(::ir::IrContext::Instance(), 3); - pm.AddPass(::ir::CreateInplacePass()); + ::pir::PassManager pm(::pir::IrContext::Instance(), 3); + pm.AddPass(::pir::CreateInplacePass()); pm.Run(ir_res.get()); } return ir_res; } -std::unique_ptr<::ir::Program> ConstructBackwardIrProgram( +std::unique_ptr<::pir::Program> ConstructBackwardIrProgram( const paddle::framework::BlockDesc *backward_global_block, const std::vector &out_grad, const std::vector &x_grad, const std::vector ¶ms_grad, const paddle::framework::Scope *scope) { - auto ir_ctx = ::ir::IrContext::Instance(); - auto program = std::make_unique<::ir::Program>(ir_ctx); + auto ir_ctx = ::pir::IrContext::Instance(); + auto program = std::make_unique<::pir::Program>(ir_ctx); auto local_program = paddle::framework::ProgramDesc(*(backward_global_block->Program())); @@ -527,8 +528,8 @@ std::unique_ptr<::ir::Program> ConstructBackwardIrProgram( auto res = paddle::dialect::PdOpLowerToKernelPass(program.get()); if (FLAGS_new_ir_apply_inplace_pass) { - ::ir::PassManager pm(::ir::IrContext::Instance(), 3); - pm.AddPass(::ir::CreateInplacePass()); + ::pir::PassManager pm(::pir::IrContext::Instance(), 3); + pm.AddPass(::pir::CreateInplacePass()); pm.Run(res.get()); } diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h index edbbc0e9420af..1c5602a31f872 100644 --- a/paddle/fluid/framework/executor_cache.h +++ b/paddle/fluid/framework/executor_cache.h @@ -30,9 +30,9 @@ #include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h" -#include "paddle/ir/core/dialect.h" -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/program.h" +#include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/program.h" PHI_DECLARE_bool(enable_new_ir_in_executor); namespace paddle { @@ -243,20 +243,21 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( framework::Scope* scope); std::shared_ptr 
CreateNewIRInterpreterCoreInfoToCache( - std::unique_ptr<::ir::Program> ir_prog, + std::unique_ptr<::pir::Program> ir_prog, const platform::Place& place, bool is_grad, int64_t program_id, framework::Scope* scope); -std::unique_ptr<::ir::Program> ConstructFowardIrProgram( +std::unique_ptr<::pir::Program> ConstructFowardIrProgram( const paddle::framework::BlockDesc* forward_global_block, const paddle::framework::BlockDesc* backward_global_block, const std::vector output_names, const std::vector& x, - const std::vector& params); + const std::vector& params, + const phi::Place& place); -std::unique_ptr<::ir::Program> ConstructBackwardIrProgram( +std::unique_ptr<::pir::Program> ConstructBackwardIrProgram( const paddle::framework::BlockDesc* backward_global_block, const std::vector& out_grad, const std::vector& x_grad, diff --git a/paddle/fluid/framework/executor_gc_helper.cc b/paddle/fluid/framework/executor_gc_helper.cc index fa63b4bca16ea..27342b123d6e9 100644 --- a/paddle/fluid/framework/executor_gc_helper.cc +++ b/paddle/fluid/framework/executor_gc_helper.cc @@ -24,6 +24,7 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" +#include "paddle/fluid/operators/controlflow/pylayer_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/platform/enforce.h" @@ -226,6 +227,8 @@ GetEagerDeletionCleanVarsForPartial(const ProgramDesc &origin_program, auto global_block_ops = CreateOpsFromBlock(program.Block(0)); operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( program, 0, global_block_ops); + operators::PrepareSafeEagerDeletionOnPyLayerOpAndPyLayerGradOp( + program, 0, global_block_ops); operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( program, 0, global_block_ops); operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( @@ -240,35 +243,54 @@ GetEagerDeletionCleanVarsForPartial(const ProgramDesc &origin_program, const char *kSubBlock = "sub_block"; const char *kSkipEagerDeletionVars = "skip_eager_deletion_vars"; + // NOTE: pylayer op contains may contain two blocks: forward block and + // backward block + const char *kBlocks = "blocks"; for (size_t i = 0; i < block_num; ++i) { const auto &block = program.Block(i); size_t op_num = block.OpSize(); for (size_t j = 0; j < op_num; ++j) { auto *op = block.Op(static_cast(j)); - if (!op->HasAttr(kSubBlock) || !op->HasAttr(kSkipEagerDeletionVars)) { + if ((!op->HasAttr(kSubBlock) && !op->HasAttr(kBlocks)) || + !op->HasAttr(kSkipEagerDeletionVars)) { continue; } - auto sub_block_id = op->GetAttrIfExists(kSubBlock)->ID(); - PADDLE_ENFORCE_GE(sub_block_id, - 0, - platform::errors::PermissionDenied( - "sub_block id must be non-negative number")); - PADDLE_ENFORCE_LT(sub_block_id, - block_num, - platform::errors::PermissionDenied( - "sub_block id exceeds max block num")); - PADDLE_ENFORCE_EQ( - found_skip_vars[sub_block_id], - false, - platform::errors::PermissionDenied( - "there are 2 ops which refer to the same sub_block %d", - sub_block_id)); - - found_skip_vars[sub_block_id] = true; - auto sub_block_skip_vars = - op->GetAttrIfExists>(kSkipEagerDeletionVars); - skip_vars_on_each_block[sub_block_id] = std::move(sub_block_skip_vars); + + std::vector sub_block_ids; + if (op->HasAttr(kSubBlock)) { + sub_block_ids.push_back( + op->GetAttrIfExists(kSubBlock)->ID()); + } else if 
(op->HasAttr(kBlocks)) { + const auto &blocks = + op->GetAttrIfExists>(kBlocks); + for (const auto &block : blocks) { + sub_block_ids.push_back(block->ID()); + } + } + + for (auto sub_block_id : sub_block_ids) { + PADDLE_ENFORCE_GE(sub_block_id, + 0, + platform::errors::PermissionDenied( + "sub_block id must be non-negative number")); + PADDLE_ENFORCE_LT(sub_block_id, + block_num, + platform::errors::PermissionDenied( + "sub_block id exceeds max block num")); + PADDLE_ENFORCE_EQ( + found_skip_vars[sub_block_id], + false, + platform::errors::PermissionDenied( + "there are 2 ops which refer to the same sub_block %d", + sub_block_id)); + + found_skip_vars[sub_block_id] = true; + auto sub_block_skip_vars = + op->GetAttrIfExists>( + kSkipEagerDeletionVars); + skip_vars_on_each_block[sub_block_id] = std::move(sub_block_skip_vars); + } } } diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index b6143f335d163..526847bb32de5 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -46,7 +46,7 @@ cc_library( cc_library( op_compat_sensible_pass SRCS op_compat_sensible_pass.cc - DEPS graph_pattern_detector op_def_api pass) + DEPS graph_pattern_detector op_def_api pass pir_core) cc_library( subgraph_detector SRCS subgraph_detector.cc @@ -156,6 +156,8 @@ if(WITH_TENSORRT) pass_library(preln_elementwise_groupnorm_act_pass inference) pass_library(groupnorm_act_pass inference) pass_library(trans_layernorm_fuse_pass inference) + pass_library(trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass + inference) pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference) pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference) pass_library(split_layernorm_to_math_ops_pass inference) @@ -241,8 +243,9 @@ if(WITH_XPU) pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(cast_embedding_trans_ids_to_int32_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) - pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + # pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(conv2d_bias_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index 17d2bdda56cb9..e0ab584ee3225 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -15,6 +15,8 @@ #include "paddle/fluid/framework/ir/generate_pass.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/value.h" #include "paddle/utils/blank.h" namespace paddle { @@ -47,6 +49,12 @@ class element_visitor { int index_; }; +template <> +Attribute element_visitor::operator()( + const std::vector<::pir::Value>& attr UNUSED) const { + PADDLE_THROW(platform::errors::Unimplemented("Unimplemented operand.")); +} + class operation_visitor { public: explicit operation_visitor(const proto::PassDesc::OperationType& type) diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 3e744e18bf6c8..b322e3f8bce28 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ 
b/paddle/fluid/framework/ir/graph_helper.cc @@ -704,10 +704,18 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes, ops->emplace_back(depend_desc); VLOG(4) << "add depend op"; } - if (n->Name() == "while" || n->Name() == "while_grad" || - n->Name() == "conditional_block" || - n->Name() == "conditional_block_grad" || n->Name() == "recurrent" || - n->Name() == "recurrent_grad") { + + const std::unordered_set<std::string> control_flow_ops = { + "while", + "while_grad", + "conditional_block", + "conditional_block_grad", + "recurrent", + "recurrent_grad", + "pylayer", + "pylayer_grad"}; + + if (control_flow_ops.count(n->Name())) { VLOG(1) << "Update control op attr: skip_eager_deletion_vars"; UpdateControlOpSkipEagerDeletionVars(*n, graph, graph_idx, n->Name()); } diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index ffb1606b95ccd..1e634343c7fc1 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -6,6 +6,10 @@ cc_library( conditional_block_op_eager_deletion_pass SRCS conditional_block_op_eager_deletion_pass.cc DEPS conditional_block_op_helper graph_helper pass computation_op_handle) +cc_library( + pylayer_op_eager_deletion_pass + SRCS pylayer_op_eager_deletion_pass.cc + DEPS pylayer_op_helper graph_helper pass computation_op_handle) cc_library( while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc @@ -31,6 +35,7 @@ set(EAGER_DELETETION_PASS_DEPS graph_helper pass conditional_block_op_eager_deletion_pass + pylayer_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc index b0729abfcf883..40525a14141a6 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc @@ -294,6 +294,10 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { "conditional_block_op_eager_deletion_pass"); conditional_block_op_eager_deletion_pass->Apply(graph); + auto pylayer_op_eager_deletion_pass = + ir::PassRegistry::Instance().Get("pylayer_op_eager_deletion_pass"); + pylayer_op_eager_deletion_pass->Apply(graph); + auto while_op_eager_deletion_pass = ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass"); while_op_eager_deletion_pass->Apply(graph); @@ -321,6 +325,7 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass) .RequirePassAttr(paddle::framework::ir::kGarbageCollector); USE_PASS(conditional_block_op_eager_deletion_pass); +USE_PASS(pylayer_op_eager_deletion_pass); USE_PASS(while_op_eager_deletion_pass); USE_PASS(recurrent_op_eager_deletion_pass); #ifdef PADDLE_WITH_CINN diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/pylayer_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/pylayer_op_eager_deletion_pass.cc new file mode 100644 index 0000000000000..6d2fe78ea1d12 --- /dev/null +++ b/paddle/fluid/framework/ir/memory_optimize_pass/pylayer_op_eager_deletion_pass.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/operators/controlflow/op_variant.h" +#include "paddle/fluid/operators/controlflow/pylayer_op_helper.h" +namespace paddle { +namespace framework { +namespace ir { +using OpVariant = operators::OpVariant; +class PyLayerOpEagerDeletionPass : public Pass { + protected: + void ApplyImpl(Graph *graph) const override { + auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph); + + // Find all pylayer_op and pylayer_grad_op + std::unordered_map< + size_t, + std::pair<std::vector<OpVariant>, std::vector<OpVariant>>> + target_ops; + for (auto *op : all_ops) { + auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op); + if (compute_op == nullptr) continue; + + if (compute_op->Name() == "pylayer") { + target_ops[compute_op->GetScopeIdx()].first.emplace_back( + compute_op->GetOp()); + } else if (compute_op->Name() == "pylayer_grad") { + target_ops[compute_op->GetScopeIdx()].second.emplace_back( + compute_op->GetOp()); + } + } + + // NOTE(Aurelius84): In case of @to_static, after we finish executing the + // forward graph, some necessary variables in the step_scope of pylayer_op + // should be kept for the backward graph. + if (graph->IsConstructedByPartialProgram()) { + PADDLE_ENFORCE_LE(target_ops.size(), + 1, + platform::errors::InvalidArgument( + "Unsupported multi devices if graph is constructed " + "with partial program.")); + size_t scope_idx = 0; + auto &pylayer_ops = target_ops[scope_idx].first; + auto &pylayer_grad_ops = target_ops[scope_idx].second; + + auto all_ops = graph->OriginProgram().Block(0).AllOps(); + if (pylayer_ops.empty()) { + operators::AppendOpVariantByOpName( + all_ops, std::string("pylayer"), &pylayer_ops); + } else if (pylayer_grad_ops.empty()) { + operators::AppendOpVariantByOpName( + all_ops, std::string("pylayer_grad"), &pylayer_grad_ops); + } else { + PADDLE_THROW("One of pylayer_ops or pylayer_grad_ops should be empty."); + } + } + + for (auto &ops_pair : target_ops) { + auto &pylayer_ops = ops_pair.second.first; + auto &pylayer_grad_ops = ops_pair.second.second; + operators::PrepareSafeEagerDeletionOnPyLayerOpAndPyLayerGradOp( + graph->OriginProgram(), pylayer_ops, pylayer_grad_ops); + } + + for (auto op_handler : all_ops) { + auto *compute_op = + dynamic_cast<details::ComputationOpHandle *>(op_handler); + if (compute_op == nullptr) continue; + if (compute_op->Name() == "pylayer" || + compute_op->Name() == "pylayer_grad") { + ir::Node *op_node = op_handler->Node(); + auto *op_base = compute_op->GetOp(); + if (op_base->Attrs().count("skip_eager_deletion_vars")) { + op_node->Op()->SetAttr( + "skip_eager_deletion_vars", + op_base->Attrs().at("skip_eager_deletion_vars")); + } + } + } + } +}; + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(pylayer_op_eager_deletion_pass, + paddle::framework::ir::PyLayerOpEagerDeletionPass); diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc index b404d023d487b..704f59bbace67 100644 ---
a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc +++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc @@ -23,15 +23,18 @@ namespace framework { namespace ir { namespace patterns { void EmbEltwiseLayernorm::operator()() { - // Create nodes for fused_embedding_eltwise_layernorm. - auto* emb_elt_layernorm_op = - pattern->NewNode(emb_elt_layernorm_op_repr()) - ->assert_is_op("fused_embedding_eltwise_layernorm"); + // Create nodes for fused_embedding_eltwise_layernorm or + // prompt_tuning_emb_eltwise_layernorm. + std::unordered_set<std::string> embedding_ops{ + "fused_embedding_eltwise_layernorm", + "prompt_tuning_emb_eltwise_layernorm"}; + auto* emb_elt_layernorm_op = pattern->NewNode(emb_elt_layernorm_op_repr()) + ->assert_is_ops(embedding_ops); auto* emb_elt_layernorm_out = pattern->NewNode(emb_elt_layernorm_out_repr()) - ->assert_is_op_output("fused_embedding_eltwise_layernorm", "Out"); + ->assert_is_ops_output(embedding_ops, "Out"); - // Add links for fused_embedding_eltwise_layernorm op. + // Add links for embedding_ops. emb_elt_layernorm_op->LinksTo({emb_elt_layernorm_out}); } diff --git a/paddle/fluid/framework/ir/trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass.cc new file mode 100644 index 0000000000000..6bdd56dff2087 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass.cc @@ -0,0 +1,596 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
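+// This new pass stitches together three sub-patterns: (1) two embedding
+// lookups summed by one elementwise_add, (2) further lookups chained onto
+// that running sum, and (3) a two-layer FC branch (matrix_multiply ->
+// elementwise_add -> relu -> matrix_multiply -> elementwise_add) whose
+// output is concatenated with the embedding sum and layer-normalized.
+// The matched subgraph is replaced by one fused op; a minimal sketch of
+// the OpDesc built in BuildFusion below (names here are illustrative):
+//   OpDesc fused(block);
+//   fused.SetType("prompt_tuning_emb_eltwise_layernorm");
+//   fused.SetInput("Ids", ids);              // lookup_table Ids, in chain order
+//   fused.SetInput("Embs", embs);            // matching embedding weights
+//   fused.SetInput("DenseVector", {dense});  // output of the FC branch
+//   fused.SetOutput("Out", {layer_norm_out});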
+ +#include "paddle/fluid/framework/ir/trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass.h" + +#include + +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +class Node; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +static PDNode* create_emb_vars(PDPattern* pattern, + const std::string& name, + const std::string& arg, + bool is_persist = false) { + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + PDNode* node = + pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg); + if (is_persist) return node->assert_is_persistable_var(); + return node; +} +static PDNode* create_emb_out_vars(PDPattern* pattern, + const std::string& name, + const std::string& arg) { + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + PDNode* node = pattern->NewNode(name) + ->assert_is_only_output_of_ops(embedding_ops) + ->assert_is_op_input("elementwise_add", arg) + ->AsIntermediate(); + return node; +} +void TrtPromptTuningEmbedding2Eltwise1Pattern::operator()() { + auto* lookup_table1_x = + create_emb_vars(pattern, lookup_table1_x_repr(), "Ids"); + auto* lookup_table2_x = + create_emb_vars(pattern, lookup_table2_x_repr(), "Ids"); + auto* lookup_table1_w = + create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); + auto* lookup_table2_w = + create_emb_vars(pattern, lookup_table2_w_repr(), "W", true); + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + + auto* lookup_table1 = + pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); + auto* lookup_table2 = + pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops); + auto* lookup_table1_out = + create_emb_out_vars(pattern, lookup_table1_out_repr(), "X"); + auto* lookup_table2_out = + create_emb_out_vars(pattern, lookup_table2_out_repr(), "Y"); + auto* eltwise_add = + pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); + auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) + ->assert_is_op_output("elementwise_add"); + lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) + .LinksTo({lookup_table1_out}); + lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w}) + .LinksTo({lookup_table2_out}); + eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out}) + .LinksTo({eltwise_add_out}); +} +void TrtPromptTuningEmbedding1Eltwise1Pattern::operator()() { + auto* lookup_table1_x = + create_emb_vars(pattern, lookup_table1_x_repr(), "Ids"); + auto* lookup_table1_w = + create_emb_vars(pattern, lookup_table1_w_repr(), "W", true); + std::unordered_set embedding_ops{"lookup_table", + "lookup_table_v2"}; + + auto* lookup_table1 = + pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops); + auto* lookup_table1_out = + create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y"); + auto* eltwise_add = + pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); + auto* eltwise_add_in = pattern->NewNode(eltwise_add_in_repr()) + ->assert_is_op_input("elementwise_add", "X") + ->assert_is_op_output("elementwise_add"); + auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) + ->assert_is_op_output("elementwise_add"); + lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w}) + .LinksTo({lookup_table1_out}); + eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in}) + .LinksTo({eltwise_add_out}); +} +void 
TrtPromptTuningSkipLayerNorm::operator()() { + auto* eltwise_add = + pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add"); + auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr()) + ->assert_is_op_output("elementwise_add") + ->AsIntermediate(); + + auto* mul0_x = pattern->NewNode(mul0_x_repr()) + ->assert_is_op_input("matrix_multiply", "X"); + + auto* mul0_y = pattern->NewNode(mul0_y_repr()) + ->assert_is_op_input("matrix_multiply", "Y"); + + auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matrix_multiply"); + + auto* mul0_out = pattern->NewNode(mul0_out_repr()) + ->assert_is_op_output("matrix_multiply") + ->assert_is_op_input("elementwise_add", "X") + ->AsIntermediate(); + + auto* eltadd0_b = pattern->NewNode(eltadd0_b_repr()) + ->assert_is_op_input("elementwise_add", "Y"); + + auto* eltadd0 = + pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add"); + + auto* eltadd0_out = pattern->NewNode(eltadd0_out_repr()) + ->assert_is_op_output("elementwise_add") + ->assert_is_op_input("relu") + ->AsIntermediate(); + + auto* relu = pattern->NewNode(relu_repr())->assert_is_op("relu"); + auto* relu_out = pattern->NewNode(relu_out_repr()) + ->assert_is_op_output("relu") + ->assert_is_op_input("matrix_multiply", "X") + ->AsIntermediate(); + + auto* mul1_y = pattern->NewNode(mul1_y_repr()) + ->assert_is_op_input("matrix_multiply", "Y"); + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matrix_multiply"); + + auto* mul1_out = pattern->NewNode(mul1_out_repr()) + ->assert_is_op_output("matrix_multiply") + ->assert_is_op_input("elementwise_add", "X") + ->AsIntermediate(); + + auto* eltadd1_b = pattern->NewNode(eltadd1_b_repr()) + ->assert_is_op_input("elementwise_add", "Y"); + + auto* eltadd1 = + pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); + + auto* eltadd1_out = pattern->NewNode(eltadd1_out_repr()) + ->assert_is_op_output("elementwise_add"); + + auto* concat = pattern->NewNode(concat_repr())->assert_is_op("concat"); + + auto* concat_out = pattern->NewNode(concat_out_repr()) + ->assert_is_op_output("concat") + ->assert_is_op_input("layer_norm", "X") + ->AsIntermediate(); + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr()) + ->assert_is_op_output("layer_norm", "Y") + ->AsOutput(); + auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Bias"); + auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layer_norm", "Scale"); + + eltwise_add->LinksTo({eltwise_add_out}); + + mul0->LinksFrom({mul0_x, mul0_y}).LinksTo({mul0_out}); + + eltadd0->LinksFrom({mul0_out, eltadd0_b}).LinksTo({eltadd0_out}); + + relu->LinksFrom({eltadd0_out}).LinksTo({relu_out}); + + mul1->LinksFrom({relu_out, mul1_y}).LinksTo({mul1_out}); + + eltadd1->LinksFrom({mul1_out, eltadd1_b}).LinksTo({eltadd1_out}); + + concat->LinksFrom({eltadd1_out, eltwise_add_out}).LinksTo({concat_out}); + + layer_norm->LinksFrom({concat_out, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo({layer_norm_out}); +} + +} // namespace patterns + +int TrtPromptTuningEmbeddingEltwiseLayerNormFusePass::BuildFusion( + Graph* graph, const std::string& name_scope + /*const Scope* scope*/) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + std::string pos_id = 
Get("tensorrt_transformer_posid"); + std::string mask_id = Get("tensorrt_transformer_maskid"); + std::vector>> start_pattern_in_nodes; + std::vector start_pattern_out_node; + std::vector> start_pattern_remove_nodes; + + // Create pattern. + patterns::TrtPromptTuningEmbedding2Eltwise1Pattern start_pattern( + pattern, name_scope + "/start"); + start_pattern(); + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + lookup_table1_out, lookup_table1_out, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + lookup_table2_out, lookup_table2_out, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern); + + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass(TrtPromptTuningEmbedding2Eltwise1Pattern) in op " + "compat failed."; + return; + } + std::vector> ins; + ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w)); + ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w)); + start_pattern_in_nodes.push_back(ins); + start_pattern_out_node.push_back(eltwise_add_out); + + std::unordered_set rm_nodes; + rm_nodes.insert({lookup_table1, + lookup_table2, + lookup_table1_out, + lookup_table2_out, + eltwise_add, + eltwise_add_out}); + start_pattern_remove_nodes.push_back(rm_nodes); + }; + gpd(graph, handler); + + std::vector> inner_pattern_ins; + std::vector inner_pattern_tmp_in; + std::vector inner_pattern_out; + std::vector> inner_pattern_remove_nodes; + + GraphPatternDetector gpd2; + auto* pattern2 = gpd2.mutable_pattern(); + patterns::TrtPromptTuningEmbedding1Eltwise1Pattern second_pattern( + pattern2, name_scope + "/second"); + second_pattern(); + auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + lookup_table1_out, lookup_table1_out, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern); + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass(TrtPromptTuningEmbedding1Eltwise1Pattern) in op " + "compat failed."; + return; + } + auto in = std::make_pair(lookup_table1_x, lookup_table1_w); + inner_pattern_ins.push_back(in); + inner_pattern_tmp_in.push_back(eltwise_add_in); + inner_pattern_out.push_back(eltwise_add_out); + + std::unordered_set rm_nodes; + rm_nodes.insert( + {lookup_table1, lookup_table1_out, eltwise_add, eltwise_add_out}); + inner_pattern_remove_nodes.push_back(rm_nodes); + }; + gpd2(graph, handler2); + + std::vector end_pattern_elt_out; + std::vector end_pattern_eltadd1; + std::vector end_pattern_eltadd1_out; + std::vector end_pattern_concat; + std::vector 
end_pattern_concat_out; + std::vector end_pattern_scales; + std::vector end_pattern_biases; + std::vector end_pattern_out; + std::vector end_patter_layernorms; + std::vector> end_pattern_remove_nodes; + GraphPatternDetector gpd3; + auto* pattern3 = gpd3.mutable_pattern(); + patterns::TrtPromptTuningSkipLayerNorm skip_layernorm_pattern( + pattern3, name_scope + "/third"); + skip_layernorm_pattern(); + auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + eltwise_add_out, eltwise_add_out, skip_layernorm_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(concat, concat, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_out, layer_norm_out, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_bias, layer_norm_bias, skip_layernorm_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, skip_layernorm_pattern); + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "Pass(TrtPromptTuningSkipLayerNorm) in op compat failed."; + return; + } + end_pattern_elt_out.push_back(eltwise_add_out); + std::unordered_set rm_nodes; + rm_nodes.insert({concat}); + rm_nodes.insert({concat_out}); + rm_nodes.insert({layer_norm}); + end_pattern_remove_nodes.push_back(rm_nodes); + + end_pattern_eltadd1.push_back(eltadd1); + end_pattern_eltadd1_out.push_back(eltadd1_out); + end_pattern_concat.push_back(concat); + end_pattern_concat_out.push_back(concat_out); + end_pattern_biases.push_back(layer_norm_bias); + end_pattern_scales.push_back(layer_norm_scale); + end_pattern_out.push_back(layer_norm_out); + end_patter_layernorms.push_back(layer_norm); + }; + gpd3(graph, handler3); + + if (start_pattern_in_nodes.empty() || end_pattern_elt_out.empty()) { + return 0; + } + // only reserve the subgraphs that in connected domains. + int fusion_count = 0; + // fusion_id for (i, k, js) + std::vector>>> + fusion_ids; + for (size_t i = 0; i < start_pattern_in_nodes.size(); ++i) { + Node* tmp = start_pattern_out_node[i]; + Node* old_tmp = nullptr; + // get correct inner pattern node order. 
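+    // Starting from each start-pattern output, follow the inner pattern
+    // whose chained input matches the current node until none matches;
+    // `js` records the visited inner-pattern indices in chain order, so
+    // the walk ends at a node that should feed one of the end patterns.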
+ std::vector js; + while (tmp != old_tmp) { + old_tmp = tmp; + for (size_t j = 0; j < inner_pattern_tmp_in.size(); ++j) { + if (inner_pattern_tmp_in[j] == tmp) { + tmp = inner_pattern_out[j]; + js.push_back(j); + break; + } + } + } + + for (size_t k = 0; k < end_pattern_elt_out.size(); ++k) { + if (tmp == end_pattern_elt_out[k]) { + fusion_ids.push_back(std::make_pair(i, std::make_pair(k, js))); + break; + } + } + } + + for (auto& fusion_id : fusion_ids) { + int i = fusion_id.first; + int k = fusion_id.second.first; + std::vector js = fusion_id.second.second; + + std::vector ids; + std::vector embs; + + auto ids0_shape = start_pattern_in_nodes[i][0].first->Var()->GetShape(); + bool flag = true; + for (auto& item : start_pattern_in_nodes[i]) { + auto ids_shape = item.first->Var()->GetShape(); + if (ids_shape.size() != ids0_shape.size()) { + VLOG(3) << "Shape check failed, ids'rank are not all equal, stop " + "trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass."; + flag = false; + } else { + for (size_t j = 0; j < ids_shape.size(); ++j) { + if (ids_shape[j] != ids0_shape[j]) { + VLOG(3) + << "Shape check failed, ids.shape[i] are not all equal, stop " + "trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass."; + flag = false; + } + } + } + ids.push_back(item.first->Name()); + embs.push_back(item.second->Name()); + } + for (auto item : js) { + auto ids_shape = inner_pattern_ins[item].first->Var()->GetShape(); + if (ids_shape.size() != ids0_shape.size()) { + VLOG(3) << "Shape check failed, ids'rank are not all equal, stop " + "trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass."; + flag = false; + } else { + for (size_t j = 0; j < ids_shape.size(); ++j) { + if (ids_shape[j] != ids0_shape[j]) { + VLOG(3) + << "Shape check failed, ids.shape[i] are not all equal, stop " + "trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass."; + flag = false; + } + } + } + ids.push_back(inner_pattern_ins[item].first->Name()); + embs.push_back(inner_pattern_ins[item].second->Name()); + } + + if (flag) { + OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block()); + new_op_desc.SetType("prompt_tuning_emb_eltwise_layernorm"); + new_op_desc.SetInput("Ids", ids); + new_op_desc.SetInput("Embs", embs); + new_op_desc.SetInput("PosId", {pos_id}); + new_op_desc.SetInput("MaskId", {mask_id}); + + new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()}); + new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()}); + new_op_desc.SetInput("DenseVector", {end_pattern_eltadd1_out[k]->Name()}); + new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()}); + new_op_desc.SetAttr("epsilon", + end_patter_layernorms[k]->Op()->GetAttr("epsilon")); + + if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) { + new_op_desc.SetAttr("enable_int8", true); + new_op_desc.SetAttr( + "out_threshold", + end_patter_layernorms[k]->Op()->GetAttr("out_threshold")); + } + + auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc); + + for (auto& item : start_pattern_in_nodes[i]) { + IR_NODE_LINK_TO(item.first, embedding_eltwise_layernorm); + IR_NODE_LINK_TO(item.second, embedding_eltwise_layernorm); + } + for (auto item : js) { + IR_NODE_LINK_TO(inner_pattern_ins[item].first, + embedding_eltwise_layernorm); + IR_NODE_LINK_TO(inner_pattern_ins[item].second, + embedding_eltwise_layernorm); + } + IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm); + IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm); + IR_NODE_LINK_TO(end_pattern_eltadd1_out[k], 
embedding_eltwise_layernorm); + IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]); + + // Remove unneeded nodes. + std::unordered_set marked_nodes; + marked_nodes.insert(start_pattern_remove_nodes[i].begin(), + start_pattern_remove_nodes[i].end()); + marked_nodes.insert(end_pattern_remove_nodes[k].begin(), + end_pattern_remove_nodes[k].end()); + for (auto item : js) { + marked_nodes.insert(inner_pattern_remove_nodes[item].begin(), + inner_pattern_remove_nodes[item].end()); + } + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + } else { + VLOG(3) << "Shape check failed, stop " + "trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass."; + } + } + return fusion_count; +} + +TrtPromptTuningEmbeddingEltwiseLayerNormFusePass:: + TrtPromptTuningEmbeddingEltwiseLayerNormFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("concat")) + .AddInput("X") + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .End(); + + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumGT(0) + .End(); +} + +void TrtPromptTuningEmbeddingEltwiseLayerNormFusePass::ApplyImpl( + Graph* graph) const { + bool with_dynamic_shape = Get("with_dynamic_shape"); + if (!with_dynamic_shape) { + VLOG(3) << "Stop this pass, because " + "trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass need: " + "use_varseqlen, " + "with_dynamic_shape." 
+ " Please reconfigure."; + return; + } + FusePassBase::Init(name_scope_, graph); + int fusion_count = + TrtPromptTuningEmbeddingEltwiseLayerNormFusePass::BuildFusion( + graph, name_scope_); + if (fusion_count > 0) { + bool use_varseqlen = Get<bool>("use_varseqlen"); + std::string pos_id = Get<std::string>("tensorrt_transformer_posid"); + std::string mask_id = Get<std::string>("tensorrt_transformer_maskid"); + + if ((use_varseqlen && !pos_id.empty() && !mask_id.empty())) { + VLOG(3) + << "start trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass"; + } else { + VLOG(3) << "Stop this pass, because " + "trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass only " + "supports use_varseqlen, please reconfigure"; + return; + } + graph->Set(kEmbEltwiseLayernormPass, new bool(true)); + } + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS( + trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass, + paddle::framework::ir::TrtPromptTuningEmbeddingEltwiseLayerNormFusePass); +REGISTER_PASS_CAPABILITY( + trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("lookup_table", 1) + .LE("lookup_table_v2", 1) + .LE("elementwise_add", 1)); diff --git a/paddle/fluid/framework/ir/trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass.h new file mode 100644 index 0000000000000..16fd38b1abed6 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass.h @@ -0,0 +1,118 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
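+// Each PATTERN_DECL_NODE(name) in this header declares the repr accessor
+// that GET_IR_NODE_FROM_SUBGRAPH uses in the .cc file; roughly (a
+// simplified sketch of the macro from graph_pattern_detector.h):
+//   std::string name_repr() const {
+//     return PDNodeName(name_scope_, repr_, id_, "name");
+//   }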
+ +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +class Graph; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct TrtPromptTuningEmbedding2Eltwise1Pattern : public PatternBase { + TrtPromptTuningEmbedding2Eltwise1Pattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "embedding2_eltwise1") {} + + void operator()(); + PATTERN_DECL_NODE(feed1); + PATTERN_DECL_NODE(feed2); + PATTERN_DECL_NODE(lookup_table1_x); + PATTERN_DECL_NODE(lookup_table2_x); + PATTERN_DECL_NODE(lookup_table1_w); + PATTERN_DECL_NODE(lookup_table2_w); + PATTERN_DECL_NODE(lookup_table1); + PATTERN_DECL_NODE(lookup_table2); + PATTERN_DECL_NODE(lookup_table1_out); + PATTERN_DECL_NODE(lookup_table2_out); + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); +}; + +struct TrtPromptTuningEmbedding1Eltwise1Pattern : public PatternBase { + TrtPromptTuningEmbedding1Eltwise1Pattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "embedding1_eltwise1") {} + void operator()(); + PATTERN_DECL_NODE(feed1); + PATTERN_DECL_NODE(lookup_table1_x); + PATTERN_DECL_NODE(lookup_table1_w); + PATTERN_DECL_NODE(lookup_table1); + PATTERN_DECL_NODE(lookup_table1_out); + PATTERN_DECL_NODE(eltwise_add_in); + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); +}; + +struct TrtPromptTuningSkipLayerNorm : public PatternBase { + TrtPromptTuningSkipLayerNorm(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "skip_layernorm") {} + void operator()(); + + PATTERN_DECL_NODE(eltwise_add); + PATTERN_DECL_NODE(eltwise_add_out); + PATTERN_DECL_NODE(mul0_x); + PATTERN_DECL_NODE(mul0_y); + PATTERN_DECL_NODE(mul0); + PATTERN_DECL_NODE(mul0_out); + PATTERN_DECL_NODE(eltadd0_b); + PATTERN_DECL_NODE(eltadd0); + PATTERN_DECL_NODE(eltadd0_out); + PATTERN_DECL_NODE(relu); + PATTERN_DECL_NODE(relu_out); + PATTERN_DECL_NODE(mul1_y); + PATTERN_DECL_NODE(mul1); + PATTERN_DECL_NODE(mul1_out); + PATTERN_DECL_NODE(eltadd1_b); + PATTERN_DECL_NODE(eltadd1); + PATTERN_DECL_NODE(eltadd1_out); + PATTERN_DECL_NODE(concat); + PATTERN_DECL_NODE(concat_out); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_out); +}; +} // namespace patterns + +class TrtPromptTuningEmbeddingEltwiseLayerNormFusePass : public FusePassBase { + public: + TrtPromptTuningEmbeddingEltwiseLayerNormFusePass(); + virtual ~TrtPromptTuningEmbeddingEltwiseLayerNormFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const; + int BuildFusion(Graph* graph, const std::string& name_scope + /*const Scope* scope*/) const; + const std::string name_scope_{ + "trt_prompt_tuning_embedding_eltwise_layernorm_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/conv2d_bias_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/conv2d_bias_fuse_pass.cc new file mode 100644 index 0000000000000..7f53507a85c83 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/conv2d_bias_fuse_pass.cc @@ -0,0 +1,339 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/conv2d_bias_fuse_pass.h" + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct FcBiasPattern : public PatternBase { + FcBiasPattern(PDPattern* pattern, + const std::string& name_scope, + const std::string& mul_type); + + // declare operator node's name + PATTERN_DECL_NODE(ew_bias_add); + // declare variable node's name + PATTERN_DECL_NODE(mul_out); + PATTERN_DECL_NODE(ew_bias_add_x); + PATTERN_DECL_NODE(ew_bias_add_out); + + private: + std::string mul_type_; +}; + +FcBiasPattern::FcBiasPattern(PDPattern* pattern, + const std::string& name_scope, + const std::string& mul_type) + : PatternBase(pattern, name_scope, name_scope), mul_type_(mul_type) { + auto* mul_out = pattern->NewNode(mul_out_repr()) + ->assert_is_op_output(mul_type_, "Out") + ->assert_is_op_input("elementwise_add", "Y") + ->assert_has_n_outputs(1); + auto* ew_bias_add = pattern->NewNode(ew_bias_add_repr()) + ->assert_is_op("elementwise_add") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axis = op_desc->GetAttrIfExists("axis"); + return axis == -1; + }); + auto* ew_bias_add_x = pattern->NewNode(ew_bias_add_x_repr()) + ->assert_is_op_input("elementwise_add", "X") + ->assert_is_persistable_var() + ->assert_has_n_outputs(1); + auto* ew_bias_add_out = pattern->NewNode(ew_bias_add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + ew_bias_add->LinksFrom({mul_out, ew_bias_add_x}).LinksTo({ew_bias_add_out}); +} + +struct Conv2dBiasPattern : public PatternBase { + Conv2dBiasPattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(ew_bias_add); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(ew_bias_add_y); + PATTERN_DECL_NODE(ew_bias_add_out); +}; + +Conv2dBiasPattern::Conv2dBiasPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* x = pattern->NewNode(x_repr()) + ->assert_is_op_output("conv2d", "Output") + ->assert_is_op_input("elementwise_add", "X") + ->assert_has_n_outputs(1); + auto* ew_bias_add = pattern->NewNode(ew_bias_add_repr()) + ->assert_is_op("elementwise_add") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axis = op_desc->GetAttrIfExists("axis"); + return axis == -1; + }); + auto* ew_bias_add_y = pattern->NewNode(ew_bias_add_y_repr()) + ->assert_is_op_input("elementwise_add", "Y") + ->assert_is_persistable_var() + ->assert_has_n_outputs(1) + ->assert_more([](Node* node) { + auto y_shape = node->Var()->GetShape(); + size_t y_rank = 
y_shape.size(); + return y_rank == 4 && y_shape[0] == 1 && + y_shape[2] == 1 && y_shape[3] == 1; + }); + auto* ew_bias_add_out = pattern->NewNode(ew_bias_add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + ew_bias_add->LinksFrom({x, ew_bias_add_y}).LinksTo({ew_bias_add_out}); +} + +struct ScaleFusePattern : public PatternBase { + ScaleFusePattern(PDPattern* pattern, const std::string& name_scope); + // declare operator node's name + PATTERN_DECL_NODE(ele_mul); + PATTERN_DECL_NODE(ele_add); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(ele_mul_y); + PATTERN_DECL_NODE(ele_mul_out); + PATTERN_DECL_NODE(ele_add_y); + PATTERN_DECL_NODE(ele_add_out); +}; + +ScaleFusePattern::ScaleFusePattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + // ele_mul op + auto ele_mul = + pattern->NewNode(ele_mul_repr())->assert_is_op("elementwise_mul"); + auto x = pattern->NewNode(x_repr()) + ->assert_is_op_input("elementwise_mul", "X") + ->AsInput(); + auto ele_mul_y = pattern->NewNode(ele_mul_y_repr()) + ->assert_is_op_input("elementwise_mul", "Y") + ->assert_is_persistable_var() + ->assert_has_n_outputs(1) + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }); + auto ele_mul_out = pattern->NewNode(ele_mul_out_repr()) + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_is_op_input("elementwise_add", "X") + ->assert_has_n_outputs(1); + ele_mul->LinksFrom({x, ele_mul_y}).LinksTo({ele_mul_out}); + // ele_add op + auto ele_add = + pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add"); + auto ele_add_y = pattern->NewNode(ele_add_y_repr()) + ->assert_is_op_input("elementwise_add", "Y") + ->assert_is_persistable_var() + ->assert_has_n_outputs(1) + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }); + auto ele_add_out = pattern->NewNode(ele_add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + ele_add->LinksFrom({ele_mul_out, ele_add_y}).LinksTo({ele_add_out}); +} + +} // namespace patterns + +void Conv2dBiasFusePass::TransFcBias(ir::Graph* graph, + const std::string& mul_type) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + GraphPatternDetector gpd; + patterns::FcBiasPattern pattern(gpd.mutable_pattern(), name_scope_, mul_type); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle TransFcBias fuse"; + // declare operator node's name + GET_IR_NODE(ew_bias_add); + // declare variable node's name + GET_IR_NODE(mul_out); + GET_IR_NODE(ew_bias_add_x); + GET_IR_NODE(ew_bias_add_out); + + // trans link order of x && y for ew_bias_add op + auto ew_bias_add_desc = ew_bias_add->Op(); + IR_NODE_UNLINK(mul_out, ew_bias_add); + IR_NODE_UNLINK(ew_bias_add_x, ew_bias_add); + ew_bias_add_desc->RemoveInput("X"); + ew_bias_add_desc->RemoveInput("Y"); + ew_bias_add_desc->Flush(); + ew_bias_add_desc->SetInput("X", {mul_out->Name()}); + ew_bias_add_desc->SetInput("Y", {ew_bias_add_x->Name()}); + IR_OP_VAR_LINK(mul_out, ew_bias_add); + IR_OP_VAR_LINK(ew_bias_add_x, ew_bias_add); + + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void Conv2dBiasFusePass::FoldConv2dBias(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + 
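+  // The handler below folds elementwise_add(conv2d_out, bias) into the
+  // conv-bias form: the persistable 4-D bias of shape [1, C, 1, 1] is
+  // reshaped to a 1-D tensor of length C and the add gets attr axis=1,
+  // e.g. a [1, 64, 1, 1] bias becomes shape [64], which the conv2d_xpu
+  // fusion can then consume directly.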
Init(name_scope_, graph); + GraphPatternDetector gpd; + patterns::Conv2dBiasPattern pattern(gpd.mutable_pattern(), name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle TransEwBiasAdd fuse"; + // declare operator node's name + GET_IR_NODE(ew_bias_add); + // declare variable node's name + GET_IR_NODE(x); + GET_IR_NODE(ew_bias_add_y); + GET_IR_NODE(ew_bias_add_out); + + auto* scope = param_scope(); + // resize 4D dims of ew_bias_add_y to 1-D dim + auto ew_bias_add_desc = ew_bias_add->Op(); + ew_bias_add_desc->SetAttr("axis", 1); + auto* ew_bias_add_y_desc = ew_bias_add_y->Var(); + auto y_shape = ew_bias_add_y_desc->GetShape(); + ew_bias_add_y_desc->SetShape({y_shape[1]}); + auto* ew_bias_add_y_tensor = + scope->GetVar(ew_bias_add_y->Name())->GetMutable(); + ew_bias_add_y_tensor->Resize(phi::make_ddim({y_shape[1]})); + ew_bias_add_desc->Flush(); + + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void Conv2dBiasFusePass::FuseScaleOps(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::ScaleFusePattern pattern(gpd.mutable_pattern(), name_scope_); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle FuseScaleOps"; + /* declare operator node's name */ + GET_IR_NODE(ele_mul); + GET_IR_NODE(ele_add); + // declare variable node's name + GET_IR_NODE(x); + GET_IR_NODE(ele_mul_y); + GET_IR_NODE(ele_mul_out); + GET_IR_NODE(ele_add_y); + GET_IR_NODE(ele_add_out); + + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + // get attrs of scale from ele_mul && ele_add + const auto& ele_mul_y_t = + scope->GetVar(ele_mul_y->Name())->GetMutable(); + auto ele_mul_y_t_len = ele_mul_y_t->numel(); + PADDLE_ENFORCE_EQ( + ele_mul_y_t_len, + 1, + platform::errors::InvalidArgument("the size(%ld) of ele_mul y tensor " + "must equal 1", + ele_mul_y_t_len)); + const auto& ele_add_y_t = + scope->GetVar(ele_add_y->Name())->GetMutable(); + auto ele_add_y_t_len = ele_add_y_t->numel(); + PADDLE_ENFORCE_EQ( + ele_add_y_t_len, + 1, + platform::errors::InvalidArgument("the size(%ld) of ele_add y tensor " + "must equal 1", + ele_mul_y_t_len)); + auto tensor_type = ele_mul_y_t->dtype(); + float scale_val_ = 1.f; + float bias_val_ = 0.f; + if (tensor_type == phi::DataType::FLOAT16) { + CastToFp32(ele_mul_y_t, nullptr); + CastToFp32(ele_add_y_t, nullptr); + } + float* ele_mul_y_ptr = + ele_mul_y_t->mutable_data(paddle::platform::CPUPlace()); + float* ele_add_y_ptr = + ele_add_y_t->mutable_data(paddle::platform::CPUPlace()); + scale_val_ = ele_mul_y_ptr[0]; + bias_val_ = ele_add_y_ptr[0]; + // replace ele_mul+ele_add with scale + OpDesc new_desc; + new_desc.SetType("scale"); + new_desc.SetAttr("bias_after_scale", true); + new_desc.SetAttr("scale", scale_val_); + new_desc.SetAttr("bias", bias_val_); + new_desc.SetInput("X", {x->Name()}); + new_desc.SetOutput("Out", {ele_add_out->Name()}); + new_desc.Flush(); + + auto fused_node = graph->CreateOpNode(&new_desc); + IR_NODE_LINK_TO(x, fused_node); + IR_NODE_LINK_TO(fused_node, ele_add_out); + + std::unordered_set del_node_set = { + ele_mul, ele_mul_y, ele_mul_out, ele_add, ele_add_y}; + GraphSafeRemoveNodes(graph, del_node_set); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void 
Conv2dBiasFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + // for conv2d + scale fuse + FuseScaleOps(graph); + // for conv2d + ew_bias_add + scale fuse + FoldConv2dBias(graph); + // for matmul + ew_bias_add fuse + for (auto mul_type : {"mul", "matmul", "matmul_v2"}) { + TransFcBias(graph, mul_type); + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(conv2d_bias_fuse_pass, paddle::framework::ir::Conv2dBiasFusePass); + +REGISTER_PASS_CAPABILITY(conv2d_bias_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("mul", 0) + .LE("elementwise_add", 1)); diff --git a/paddle/fluid/framework/ir/xpu/conv2d_bias_fuse_pass.h b/paddle/fluid/framework/ir/xpu/conv2d_bias_fuse_pass.h new file mode 100644 index 0000000000000..7d9d3fbe3154c --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/conv2d_bias_fuse_pass.h @@ -0,0 +1,66 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +class Conv2dBiasFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void TransFcBias(ir::Graph* graph, const std::string& mul_type) const; + + void FoldConv2dBias(ir::Graph* graph) const; + /* + For example: + x + | + elementwise_mul + | + elementwise_add + | + out + ------------------------------------------------------ + After the pass is applied: + x + | + bias --- scale_op --- scale + | + out + */ + void FuseScaleOps(ir::Graph* graph) const; + + const std::string name_scope_{"conv2d_bias_fuse_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc index 697f90e38b7ce..502c275a419d3 100644 --- a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc @@ -46,12 +46,14 @@ struct Conv2dXPUPattern : public PatternBase { const std::string& act_type, bool with_conv_bias, bool with_bn, + bool with_scale, bool with_branch_x, bool with_branch_y); // declare operator node's name PATTERN_DECL_NODE(conv); PATTERN_DECL_NODE(ew_bias_add); PATTERN_DECL_NODE(bn); + PATTERN_DECL_NODE(scale); PATTERN_DECL_NODE(ew_branch_add); PATTERN_DECL_NODE(act); // declare variable node's name @@ -69,6 +71,7 @@ struct Conv2dXPUPattern : public PatternBase { PATTERN_DECL_NODE(bn_mean_out); PATTERN_DECL_NODE(bn_saved_var); 
PATTERN_DECL_NODE(bn_saved_mean); + PATTERN_DECL_NODE(scale_out); PATTERN_DECL_NODE(ew_branch_add_in); PATTERN_DECL_NODE(ew_branch_add_out); PATTERN_DECL_NODE(act_out); @@ -78,6 +81,7 @@ struct Conv2dXPUPattern : public PatternBase { std::string act_type_; bool with_conv_bias_{false}; bool with_bn_{false}; + bool with_scale_{false}; bool with_branch_{false}; bool with_branch_x_{false}; bool with_branch_y_{false}; @@ -89,6 +93,7 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, const std::string& act_type, bool with_conv_bias, bool with_bn, + bool with_scale, bool with_branch_x, bool with_branch_y) : PatternBase(pattern, name_scope, name_scope), @@ -96,6 +101,7 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, act_type_(act_type), with_conv_bias_(with_conv_bias), with_bn_(with_bn), + with_scale_(with_scale), with_branch_(with_branch_x || with_branch_y), with_branch_x_(with_branch_x), with_branch_y_(with_branch_y) { @@ -130,7 +136,7 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, pattern->NewNode(ew_bias_add_repr())->assert_is_op("elementwise_add"); ew_bias_add_out = pattern->NewNode(ew_bias_add_out_repr()) ->assert_is_op_output("elementwise_add", "Out"); - if (with_bn_ || with_branch_ || !act_type_.empty()) { + if (with_bn_ || with_scale_ || with_branch_ || !act_type_.empty()) { ew_bias_add_out->assert_has_n_outputs(1); } ew_bias_add->LinksFrom({conv_out, ew_bias_add_y}) @@ -151,6 +157,8 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, PDNode* ew_branch_add = nullptr; PDNode* ew_branch_add_in = nullptr; PDNode* ew_branch_add_out = nullptr; + PDNode* scale = nullptr; + PDNode* scale_out = nullptr; PDNode* act = nullptr; PDNode* act_out = nullptr; // batch_norm op @@ -179,7 +187,7 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm"); bn_out = pattern->NewNode(bn_out_repr())->assert_is_op_output("batch_norm", "Y"); - if (with_branch_ || !act_type_.empty()) { + if (with_scale_ || with_branch_ || !act_type_.empty()) { bn_out->assert_has_n_outputs(1); } bn_mean_out = pattern->NewNode(bn_mean_out_repr()) @@ -196,10 +204,23 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, } else { bn_out = ew_bias_add_out; } + // scale op + if (with_scale_) { + bn_out->assert_is_op_input("scale", "X"); + scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); + scale_out = + pattern->NewNode(scale_out_repr())->assert_is_op_output("scale", "Out"); + if (with_bn_ || !act_type_.empty()) { + scale_out->assert_has_n_outputs(1); + } + scale->LinksFrom({bn_out}).LinksTo({scale_out}); + } else { + scale_out = bn_out; + } // ew_branch_add op if (with_branch_) { if (with_branch_x_) { - bn_out->assert_is_op_input("elementwise_add", "Y"); + scale_out->assert_is_op_input("elementwise_add", "Y"); ew_branch_add_in = pattern->NewNode(ew_branch_add_in_repr()) ->assert_is_op_input("elementwise_add", "X") ->AsInput(); @@ -226,7 +247,7 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, ew_branch_add->LinksFrom({bn_out, ew_branch_add_in}) .LinksTo({ew_branch_add_out}); } else { - ew_branch_add_out = bn_out; + ew_branch_add_out = scale_out; } // act op if (!act_type_.empty()) { @@ -330,6 +351,7 @@ class Conv2dXPUFusePass : public FusePassBase { const std::string& act_type, bool with_conv_bias, bool with_bn, + bool with_scale, bool with_branch_x, bool with_branch_y) const; @@ -345,28 +367,31 @@ void Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph) const { for (auto conv_type : {"conv2d", "depthwise_conv2d"}) { for 
(auto with_conv_bias : {true, false}) { for (auto with_bn : {true, false}) { - for (auto with_branch_x : {true, false}) { - for (auto with_branch_y : {true, false}) { - for (auto act_type : { - "relu", - "sigmoid", - "tanh", - "gelu", - "leaky_relu", - "hard_swish", - "hard_sigmoid", - "relu6", - "swish", - "", - }) { - if (with_branch_x && with_branch_y) continue; - found_subgraph_count += ApplyImpl(graph, - conv_type, - act_type, - with_conv_bias, - with_bn, - with_branch_x, - with_branch_y); + for (auto with_scale : {true, false}) { + for (auto with_branch_x : {true, false}) { + for (auto with_branch_y : {true, false}) { + for (auto act_type : { + "relu", + "sigmoid", + "tanh", + "gelu", + "leaky_relu", + "hard_swish", + "hard_sigmoid", + "relu6", + "swish", + "", + }) { + if (with_branch_x && with_branch_y) continue; + found_subgraph_count += ApplyImpl(graph, + conv_type, + act_type, + with_conv_bias, + with_bn, + with_scale, + with_branch_x, + with_branch_y); + } } } } @@ -381,6 +406,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, const std::string& act_type, bool with_conv_bias, bool with_bn, + bool with_scale, bool with_branch_x, bool with_branch_y) const { GraphPatternDetector gpd; @@ -390,6 +416,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, act_type, with_conv_bias, with_bn, + with_scale, with_branch_x, with_branch_y); int found_subgraph_count = 0; @@ -400,6 +427,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, GET_IR_NODE(conv); GET_IR_NODE(ew_bias_add); GET_IR_NODE(bn); + GET_IR_NODE(scale); GET_IR_NODE(ew_branch_add); GET_IR_NODE(act); /* declare variable node's name*/ @@ -417,6 +445,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, GET_IR_NODE(bn_mean_out); GET_IR_NODE(bn_saved_var); GET_IR_NODE(bn_saved_mean); + GET_IR_NODE(scale_out); GET_IR_NODE(ew_branch_add_in); GET_IR_NODE(ew_branch_add_out); GET_IR_NODE(act_out); @@ -429,6 +458,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, auto* filter_t = scope->FindVar(conv_filter->Name())->GetMutable(); // conv_filter fp16 --> fp32 + auto filter_len = filter_t->numel(); auto filter_dtype = filter_t->dtype(); int out_dtype = proto::VarType::Type::VarType_Type_FP32; if (filter_dtype == phi::DataType::FLOAT16) { @@ -481,7 +511,6 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, float* bn_var_ptr = bn_var_t->mutable_data(paddle::platform::CPUPlace()); auto mean_len = bn_mean_t->numel(); - auto filter_len = filter_t->numel(); auto filter_stride = filter_len / mean_len; float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); if (!with_conv_bias) { // prev node is conv @@ -513,6 +542,34 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, } } } + // deal with scale op + if (with_scale) { + auto bias_len = filter_dims[0]; + float scale_val_ = 1.f; + float bias_val_ = 0.f; + scale_val_ = PADDLE_GET_CONST(float, scale->Op()->GetAttr("scale")); + bias_val_ = PADDLE_GET_CONST(float, scale->Op()->GetAttr("bias")); + bool bias_after_scale_ = + PADDLE_GET_CONST(bool, scale->Op()->GetAttr("bias_after_scale")); + // recompute bias as scale op + auto fusion_bias_t = scope->GetVar(fusion_bias_node->Name()) + ->GetMutable(); + float* fusion_bias_ptr = + fusion_bias_t->mutable_data(paddle::platform::CPUPlace()); + for (int i = 0; i < bias_len; ++i) { + if (bias_after_scale_) { + fusion_bias_ptr[i] = fusion_bias_ptr[i] * scale_val_ + bias_val_; + } else { + fusion_bias_ptr[i] = (fusion_bias_ptr[i] + bias_val_) * scale_val_; + } + } + // recompute weight as scale op + float* filter_ptr = + 
filter_t->mutable_data(paddle::platform::CPUPlace()); + for (int i = 0; i < filter_len; ++i) { + filter_ptr[i] *= scale_val_; + } + } // filter max Node* filter_int16 = nullptr; Node* filter_max = nullptr; @@ -524,6 +581,8 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, conv2d_xpu_out_name = act_out->Name(); } else if (ew_branch_add) { conv2d_xpu_out_name = ew_branch_add_out->Name(); + } else if (scale) { + conv2d_xpu_out_name = scale_out->Name(); } else if (bn) { conv2d_xpu_out_name = bn_out->Name(); } else if (ew_bias_add) { @@ -531,9 +590,9 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, } else { conv2d_xpu_out_name = conv_out->Name(); } - std::string conv_out_max_name = conv2d_xpu_out_name + "_max"; - VarDesc conv_out_max_desc(conv_out_max_name); - Node* conv2d_xpu_out_max = graph->CreateVarNode(&conv_out_max_desc); + std::string conv2d_xpu_out_max_name = conv2d_xpu_out_name + "_max"; + VarDesc conv2d_xpu_out_max_desc(conv2d_xpu_out_max_name); + Node* conv2d_xpu_out_max = graph->CreateVarNode(&conv2d_xpu_out_max_desc); // Generate conv2d_xpu op framework::OpDesc conv2d_xpu_op_desc(block); // set input&output var @@ -542,7 +601,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, conv2d_xpu_op_desc.SetInput("filter", {filter_int16->Name()}); conv2d_xpu_op_desc.SetInput("filter_max", {filter_max->Name()}); conv2d_xpu_op_desc.SetOutput("out", {conv2d_xpu_out_name}); - conv2d_xpu_op_desc.SetOutput("out_max", {conv_out_max_name}); + conv2d_xpu_op_desc.SetOutput("out_max", {conv2d_xpu_out_max_name}); // set fusion_bias input node if (has_bias) { conv2d_xpu_op_desc.SetInput("bias", {fusion_bias_node->Name()}); @@ -603,6 +662,8 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, IR_NODE_LINK_TO(conv2d_xpu, act_out); } else if (ew_branch_add_out) { IR_NODE_LINK_TO(conv2d_xpu, ew_branch_add_out); + } else if (scale_out) { + IR_NODE_LINK_TO(conv2d_xpu, scale_out); } else if (bn_out) { IR_NODE_LINK_TO(conv2d_xpu, bn_out); } else if (ew_bias_add_out) { @@ -619,6 +680,9 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, if (ew_branch_add != nullptr) { delete_nodes.insert(ew_branch_add); } + if (scale != nullptr) { + delete_nodes.insert(scale); + } if (bn != nullptr) { delete_nodes.insert(bn); delete_nodes.insert(bn_bias); @@ -630,7 +694,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, delete_nodes.insert(bn_saved_var); delete_nodes.insert(bn_saved_mean); } - if (ew_bias_add) { + if (ew_bias_add != nullptr) { delete_nodes.insert(ew_bias_add); delete_nodes.insert(ew_bias_add_y); } diff --git a/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc b/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc index 710fa94b4e0ff..7c847ae2e9ba1 100644 --- a/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc +++ b/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.cc @@ -184,8 +184,167 @@ FoldGatherSqueeze2Pattern::FoldGatherSqueeze2Pattern( squeeze2_op->LinksFrom({gather_op_out}).LinksTo({squeeze2_op_out}); } +struct FoldConv1dSqueeze2Pattern : public PatternBase { + FoldConv1dSqueeze2Pattern(PDPattern* pattern, + const std::string& name_scope, + const std::string& act_type); + + // declare operator node's name + PATTERN_DECL_NODE(squeeze2); + PATTERN_DECL_NODE(bn); + PATTERN_DECL_NODE(act); + PATTERN_DECL_NODE(unsqueeze2); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(squeeze2_out); + PATTERN_DECL_NODE(bn_bias); + PATTERN_DECL_NODE(bn_mean); + 
PATTERN_DECL_NODE(bn_scale); + PATTERN_DECL_NODE(bn_var); + PATTERN_DECL_NODE(bn_out); + PATTERN_DECL_NODE(bn_mean_out); + PATTERN_DECL_NODE(bn_saved_mean); + PATTERN_DECL_NODE(bn_saved_var); + PATTERN_DECL_NODE(bn_var_out); + PATTERN_DECL_NODE(act_out); + PATTERN_DECL_NODE(unsqueeze2_out); + + private: + std::string act_type_; +}; + +FoldConv1dSqueeze2Pattern::FoldConv1dSqueeze2Pattern( + PDPattern* pattern, + const std::string& name_scope, + const std::string& act_type) + : PatternBase(pattern, name_scope, name_scope), act_type_(act_type) { + auto* x = pattern->NewNode(x_repr()) + ->assert_is_op_input("squeeze2", "X") + ->assert_more([](Node* node) { + auto x_shape = node->Var()->GetShape(); + size_t x_rank = x_shape.size(); + return x_rank == 4 && x_shape[2] == 1; + }); + auto* squeeze2 = pattern->NewNode(squeeze2_repr()) + ->assert_is_op("squeeze2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axes_array = + op_desc->GetAttrIfExists>("axes"); + return axes_array == std::vector{-2} || + axes_array == std::vector{2}; + }); + auto* squeeze2_out = pattern->NewNode(squeeze2_out_repr()) + ->assert_is_op_output("squeeze2", "Out") + ->assert_is_op_input("batch_norm", "X"); + squeeze2->LinksFrom({x}).LinksTo({squeeze2_out}); + + auto* bn_bias = pattern->NewNode(bn_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Bias") + ->assert_has_n_outputs(1); + auto* bn_mean = pattern->NewNode(bn_mean_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Mean") + ->assert_has_n_outputs(1); + auto* bn_scale = pattern->NewNode(bn_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Scale") + ->assert_has_n_outputs(1); + auto* bn_var = pattern->NewNode(bn_var_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Variance") + ->assert_has_n_outputs(1); + auto* bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm"); + auto* bn_out = pattern->NewNode(bn_out_repr()) + ->assert_is_op_output("batch_norm", "Y") + ->assert_is_op_input(act_type_, "X"); + auto* bn_mean_out = pattern->NewNode(bn_mean_out_repr()) + ->assert_is_op_output("batch_norm", "MeanOut"); + auto* bn_saved_mean = pattern->NewNode(bn_saved_mean_repr()) + ->assert_is_op_output("batch_norm", "SavedMean"); + auto* bn_var_out = pattern->NewNode(bn_var_out_repr()) + ->assert_is_op_output("batch_norm", "VarianceOut"); + auto* bn_saved_var = pattern->NewNode(bn_saved_var_repr()) + ->assert_is_op_output("batch_norm", "SavedVariance"); + bn->LinksFrom({squeeze2_out, bn_bias, bn_mean, bn_scale, bn_var}) + .LinksTo({bn_out, bn_mean_out, bn_var_out, bn_saved_mean, bn_saved_var}); + + auto act = pattern->NewNode(act_repr())->assert_is_op(act_type_); + auto act_out = pattern->NewNode(act_out_repr()) + ->assert_is_op_output(act_type_, "Out") + ->assert_is_op_input("unsqueeze2", "X"); + act->LinksFrom({bn_out}).LinksTo({act_out}); + + auto* unsqueeze2 = + pattern->NewNode(unsqueeze2_repr()) + ->assert_is_op("unsqueeze2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto axes_array = + op_desc->GetAttrIfExists>("axes"); + return axes_array == std::vector{-2} || + axes_array == std::vector{2}; + }); + auto* unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr()) + ->assert_is_op_output("unsqueeze2", "Out"); + unsqueeze2->LinksFrom({act_out}).LinksTo({unsqueeze2_out}); +} + } // namespace patterns +void 
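The pattern constructor above accepts axes {-2} or {2} interchangeably. That is sound because squeeze2/unsqueeze2 count negative axes from the end, and the pattern has already asserted a rank-4 input with x_shape[2] == 1, so both spellings name the same dimension. A small sketch of the normalization rule (hypothetical helper, shown only to make the equivalence explicit):

#include <cassert>
#include <cstdint>

// Negative axes are offsets from the rank: axis < 0 means axis + rank.
int64_t NormalizeAxis(int64_t axis, int64_t rank) {
  return axis < 0 ? axis + rank : axis;
}

int main() {
  assert(NormalizeAxis(-2, 4) == NormalizeAxis(2, 4));  // both select dim 2
  return 0;
}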
RedundantUnsqueeze2EliminationPass::FoldConv1dSqueeze2Ops( + ir::Graph* graph, const std::string& act_type) const { + GraphPatternDetector gpd; + patterns::FoldConv1dSqueeze2Pattern pattern( + gpd.mutable_pattern(), name_scope_, act_type); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle FoldConv1dSqueeze2Ops"; + // declare operator node's name + GET_IR_NODE(squeeze2); + GET_IR_NODE(bn); + GET_IR_NODE(act); + GET_IR_NODE(unsqueeze2); + // declare variable node's name + GET_IR_NODE(x); + GET_IR_NODE(squeeze2_out); + GET_IR_NODE(bn_out); + GET_IR_NODE(act_out); + GET_IR_NODE(unsqueeze2_out); + + auto bn_op_desc = bn->Op(); + bn_op_desc->RenameInput(squeeze2_out->Var()->Name(), x->Var()->Name()); + bn_out->Var()->SetShape(x->Var()->GetShape()); + act_out->Var()->SetShape(x->Var()->GetShape()); + bn_op_desc->Flush(); + IR_NODE_LINK_TO(x, bn); + // behind unsqueeze op node + auto unsqueeze_out_link_nodes = unsqueeze2_out->outputs; + for (auto out_link_node : unsqueeze_out_link_nodes) { + auto op_desc = out_link_node->Op(); + op_desc->RenameInput(unsqueeze2_out->Var()->Name(), + act_out->Var()->Name()); + op_desc->Flush(); + IR_NODE_LINK_TO(act_out, out_link_node); + } + // delete useless node + std::unordered_set delete_nodes = { + squeeze2, squeeze2_out, unsqueeze2, unsqueeze2_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + void RedundantUnsqueeze2EliminationPass::FoldTranspose2Ops( ir::Graph* graph, const std::string& act_type) const { GraphPatternDetector gpd; @@ -315,6 +474,9 @@ void RedundantUnsqueeze2EliminationPass::ApplyImpl(ir::Graph* graph) const { FoldTranspose2Ops(graph, act_type); } FoldGatherSqueeze2Ops(graph); + for (auto act_type : {"leaky_relu", "elu"}) { + FoldConv1dSqueeze2Ops(graph, act_type); + } } } // namespace ir diff --git a/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h b/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h index 04ed41e2b6d2d..6019c135e4dad 100644 --- a/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h +++ b/paddle/fluid/framework/ir/xpu/redundant_unsqueeze_squeeze_elimination_pass.h @@ -74,6 +74,44 @@ class RedundantUnsqueeze2EliminationPass : public FusePassBase { | */ void FoldGatherSqueeze2Ops(ir::Graph* graph) const; + /* + Origin subgraph: + x filter + | | + unsqueeze2(axes={-2}) unsqueeze2(axes={-2}) + \ / + \ / + conv2d(conv1d) + | + elementwise_add + | + squeeze2(axes={-2}) + | + batch_norm + | + act + | + unsqueeze2 + | + conv2d(conv1d) + Fused subgraph: + x filter + | | + unsqueeze2(axes={-2}) unsqueeze2(axes={-2}) + \ / + \ / + conv2d(conv1d) + | + elementwise_add + | + batch_norm + | + act + | + conv2d(conv1d) + */ + void FoldConv1dSqueeze2Ops(ir::Graph* graph, + const std::string& act_type) const; const std::string name_scope_{"redundant_unsqueeze_squeeze_elimination_pass"}; }; diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index 16b18c2d7d6bd..ae30121bc930b 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -11,13 +11,13 @@ set(STANDALONE_EXECUTOR_DEPS interpreter interpretercore_garbage_collector workqueue - pd_dialect + pd_op_dialect pd_op_to_kernel_pass phi_kernel_adaptor program_translator instruction_base pd_inplace_pass 
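The FoldConv1dSqueeze2Ops handler above removes the squeeze2/unsqueeze2 pair by renaming inputs and relinking edges rather than copying any data. Reduced to a toy graph, the relink step looks like this (illustrative types only; the real pass uses OpDesc::RenameInput plus IR_NODE_LINK_TO and then GraphSafeRemoveNodes):

#include <string>
#include <vector>

struct ToyOp {
  std::string name;
  std::vector<std::string> inputs;
};

// Every consumer that read `old_out` is redirected to `new_out`; afterwards
// the bypassed producer has no readers and can be deleted safely.
void RedirectConsumers(std::vector<ToyOp>* ops,
                       const std::string& old_out,
                       const std::string& new_out) {
  for (ToyOp& op : *ops) {
    for (std::string& in : op.inputs) {
      if (in == old_out) in = new_out;
    }
  }
}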
- ir) + pir) cc_library( standalone_executor diff --git a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt index 8a9247859b85f..7706e462fef76 100644 --- a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt @@ -8,5 +8,5 @@ if(WITH_CINN AND NOT CINN_ONLY) cc_library( cinn_jit_instruction NOT_FOR_INFER SRCS cinn_jit_instruction.cc - DEPS phi cinnapi cinn_dialect runtime_dialect) + DEPS phi cinnapi cinn_op_dialect cinn_runtime_dialect) endif() diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc index d56ccc7b7ba6b..8841103213400 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc @@ -14,8 +14,8 @@ #include "paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h" -#include "paddle/cinn/hlir/dialect/runtime_dialect/ir/jit_kernel_op.h" -#include "paddle/cinn/hlir/dialect/runtime_dialect/ir/runtime_dialect.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" #include "paddle/cinn/hlir/framework/instruction.h" #include "paddle/fluid/framework/paddle2cinn/transform_type.h" @@ -93,7 +93,7 @@ class CinnJitInstruction::Impl { CinnJitInstruction::CinnJitInstruction(size_t id, const platform::Place& place, - ::ir::Operation* op, + ::pir::Operation* op, Scope* scope) : InstructionBase(id, place) { // TODO(Aurelius84): We shall simplify members of JitKernelOp to make it @@ -101,6 +101,7 @@ CinnJitInstruction::CinnJitInstruction(size_t id, // responsible to construct hlir::framework::Instruction. 
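Starting with this hunk, each instruction type in the PR caches the operation it was built from (`op_ = op;`) and exposes it through a virtual `Operation()` accessor. The interpreter later uses that pointer to recover per-op metadata, such as the callstack attribute, when a kernel throws. The shape of the pattern, with a placeholder standing in for ::pir::Operation:

struct Operation {};  // placeholder for ::pir::Operation

class InstructionSketch {
 public:
  explicit InstructionSketch(Operation* op) : op_(op) {}
  // The real classes name this accessor Operation(); renamed here only to
  // avoid clashing with the placeholder type above.
  Operation* Op() const { return op_; }

 private:
  Operation* op_{nullptr};  // not owned; the program outlives the instruction
};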
auto jit_kernel_op = op->dyn_cast(); impl_ = std::make_shared(jit_kernel_op.instruction()); + op_ = op; } void CinnJitInstruction::Run() { diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h index b20f6e08d9afc..5f5e4f74e8884 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h @@ -17,7 +17,7 @@ #include #include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" -namespace ir { +namespace pir { class Operation; } @@ -29,7 +29,7 @@ class CinnJitInstruction : public InstructionBase { public: CinnJitInstruction(size_t id, const platform::Place& place, - ::ir::Operation* op, + ::pir::Operation* op, Scope* scope); // TODO(Aurelius84): Only implement core interface and need implement GC and @@ -38,9 +38,13 @@ class CinnJitInstruction : public InstructionBase { const std::string& Name() const override; + ::pir::Operation* Operation() const override { return op_; } + private: class Impl; std::shared_ptr impl_{nullptr}; + + ::pir::Operation* op_{nullptr}; // not owned }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc index 56dafd3132c03..6836a7f306daa 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc @@ -20,7 +20,7 @@ #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" #include "paddle/fluid/platform/collective_helper.h" -#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_attribute.h" namespace paddle { namespace framework { @@ -90,28 +90,28 @@ void InstructionBase::AddInplace(Variable* in, Variable* out) { void InstructionBase::ClearInplace() { vec_inplace_in_to_out_.clear(); } void InstructionBase::SetInputs( - const std::unordered_map>& inputs) { + const std::unordered_map>& inputs) { input_index_ = inputs; } void InstructionBase::SetOutputs( - const std::unordered_map>& outputs) { + const std::unordered_map>& outputs) { output_index_ = outputs; } void InstructionBase::InitInputsOutputsIds( - ::ir::Operation* op, + ::pir::Operation* op, Scope* inner_scope, - const std::unordered_map<::ir::Value, std::string>& value_2_var_name, + const std::unordered_map& value_2_var_name, const std::map& var_name_2_id, const std::unordered_map& variable_2_var_name) { auto op_attributes = op->attributes(); auto op_name = - op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); - std::unordered_map> inputs; + op_attributes.at("op_name").dyn_cast().AsString(); + std::unordered_map> inputs; for (size_t i = 0; i < op->num_operands(); i++) { - ir::Value value = op->operand_source(i); + pir::Value value = op->operand_source(i); if (value) { PADDLE_ENFORCE_NE( value_2_var_name.find(value), @@ -130,9 +130,9 @@ void InstructionBase::InitInputsOutputsIds( } SetInputs(inputs); VLOG(8) << "finish process inputs_index"; - std::unordered_map> outputs; + std::unordered_map> outputs; for (size_t i = 0; i < op->num_results(); i++) { - ir::Value value = op->result(i); + pir::Value value = op->result(i); if (value && value.type()) { PADDLE_ENFORCE_NE( value_2_var_name.find(value), diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.h 
b/paddle/fluid/framework/new_executor/instruction/instruction_base.h index b8271a0ea0012..c20f46f15c716 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.h @@ -22,9 +22,9 @@ #include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/platform/event.h" -namespace ir { +namespace pir { class Value; -} // namespace ir +} // namespace pir namespace paddle { namespace framework { @@ -107,29 +107,29 @@ class InstructionBase { std::map& GetMutableInplaceBackMap() { return inplace_back_map_; } const std::map& GetInplaceBackMap() { return inplace_back_map_; } - const std::unordered_map<::ir::Value, std::vector>& Inputs() const { + const std::unordered_map<::pir::Value, std::vector>& Inputs() const { return input_index_; } - std::unordered_map<::ir::Value, std::vector>& GetMutableInputs() { + std::unordered_map<::pir::Value, std::vector>& GetMutableInputs() { return input_index_; } void SetInputs( - const std::unordered_map<::ir::Value, std::vector>& inputs); + const std::unordered_map<::pir::Value, std::vector>& inputs); - const std::unordered_map<::ir::Value, std::vector>& Outputs() const { + const std::unordered_map<::pir::Value, std::vector>& Outputs() const { return output_index_; } - std::unordered_map<::ir::Value, std::vector>& GetMutableOutputs() { + std::unordered_map<::pir::Value, std::vector>& GetMutableOutputs() { return output_index_; } void SetOutputs( - const std::unordered_map<::ir::Value, std::vector>& outputs); + const std::unordered_map<::pir::Value, std::vector>& outputs); - const std::unordered_set<::ir::Value>& NoNeedBuffer() const { + const std::unordered_set<::pir::Value>& NoNeedBuffer() const { return no_need_buffer_values_; } void SetNoNeedBuffer( - const std::unordered_set<::ir::Value>& no_need_buffer_values) { + const std::unordered_set<::pir::Value>& no_need_buffer_values) { no_need_buffer_values_ = no_need_buffer_values; } @@ -137,10 +137,12 @@ class InstructionBase { virtual const std::string& Name() const = 0; + virtual ::pir::Operation* Operation() const = 0; + void InitInputsOutputsIds( - ::ir::Operation* op, + ::pir::Operation* op, Scope* inner_scope, - const std::unordered_map<::ir::Value, std::string>& value_2_var_name, + const std::unordered_map<::pir::Value, std::string>& value_2_var_name, const std::map& var_name_2_id, const std::unordered_map& variable_2_var_name); @@ -176,11 +178,11 @@ class InstructionBase { std::map inplace_back_map_; - std::unordered_map<::ir::Value, std::vector> input_index_; + std::unordered_map<::pir::Value, std::vector> input_index_; - std::unordered_map<::ir::Value, std::vector> output_index_; + std::unordered_map<::pir::Value, std::vector> output_index_; - std::unordered_set<::ir::Value> no_need_buffer_values_; + std::unordered_set<::pir::Value> no_need_buffer_values_; }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc index dd6aa26a1ae53..dfa8e1ec85f9f 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc @@ -22,22 +22,28 @@ #include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/event.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/operation.h" -#include 
"paddle/ir/core/value.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" -#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" +#include "paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h" +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/collective_helper.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); +#endif namespace paddle { namespace framework { std::vector GetValueIds( - ir::Value value, + pir::Value value, Scope* inner_scope, - const std::unordered_map<::ir::Value, std::string>& value_2_var_name, + const std::unordered_map& value_2_var_name, const std::map& var_name_2_id, const std::unordered_map& variable_2_var_name) { @@ -56,14 +62,14 @@ std::vector GetValueIds( } platform::DeviceContext* ParseDeviceContext( - ir::Operation* op, + pir::Operation* op, platform::DeviceContext* origin_dev_ctx, const platform::Place& place, const std::string& execution_stream, const int stream_priority) { auto& op_attributes = op->attributes(); auto op_name = - op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); + op_attributes.at("op_name").dyn_cast().AsString(); interpreter::ContextManager& ctx_manager = interpreter::ContextManager::Instance(); @@ -109,13 +115,23 @@ platform::DeviceContext* ParseDeviceContext( // c_allreduce_op.h). Now it is just a temporary solution for ONLY // c_allreduce_sum which is used in ResNet50 distributed training. if (op_name == "c_allreduce_sum" && op_attributes.at("use_calc_stream") - .dyn_cast<::ir::BoolAttribute>() + .dyn_cast() .data() == false) { int ring_id = - op_attributes.at("ring_id").dyn_cast<::ir::Int32Attribute>().data(); - return platform::NCCLCommContext::Instance() - .Get(ring_id, place) - ->dev_context(); + op_attributes.at("ring_id").dyn_cast().data(); + if (FLAGS_dynamic_static_unified_comm) { + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + dev_ctx = static_cast( + static_cast( + comm_context_manager.Get(std::to_string(ring_id))) + ->GetDevContext()); + } else { + dev_ctx = platform::NCCLCommContext::Instance() + .Get(ring_id, place) + ->dev_context(); + } + return dev_ctx; } #endif } @@ -126,8 +142,7 @@ platform::DeviceContext* ParseDeviceContext( return origin_dev_ctx; } -OpFuncType AnalyseOpFuncType(::ir::Operation* op, - const platform::Place& place) { +OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) { if (platform::is_cpu_place(place)) { return OpFuncType::kCpuSync; } @@ -151,21 +166,21 @@ OpFuncType AnalyseOpFuncType(::ir::Operation* op, // and so that they would be dispatched to host thread. 
auto& op_attributes = op->attributes(); auto op_name = - op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); - if (op_name == kCoalesceTensor && + op_attributes.at("op_name").dyn_cast().AsString(); + if (op_name == "pd_op.coalesce_tensor" && (!platform::is_xpu_place(place) || - op->attribute("persist_output").data() == false) && - op->attribute("set_constant").data() == false && - op->attribute("copy_data").data() == false) { + op->attribute("persist_output").data() == false) && + op->attribute("set_constant").data() == false && + op->attribute("copy_data").data() == false) { return OpFuncType::kGpuSync; } // for memcpy explicitly called by user - if (platform::is_gpu_place(place) && op_name == interpreter::kMemcpyD2H) { + if (platform::is_gpu_place(place) && op_name == "pd_op.memcpy_d2h") { return OpFuncType::kGpuSync; } - if (op_name == "shape") { + if (op_name == "pd_op.shape") { return OpFuncType::kGpuSync; } return OpFuncType::kGpuAsync; diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.h b/paddle/fluid/framework/new_executor/instruction/instruction_util.h index a41ce07957e4a..c555a101d8366 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.h @@ -22,28 +22,29 @@ #include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/event.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" namespace paddle { namespace framework { std::vector GetValueIds( - ir::Value value, + pir::Value value, Scope* inner_scope, - const std::unordered_map<::ir::Value, std::string>& value_2_var_name, + const std::unordered_map<::pir::Value, std::string>& value_2_var_name, const std::map& var_name_2_id, const std::unordered_map& variable_2_var_name); platform::DeviceContext* ParseDeviceContext( - ir::Operation* op, + pir::Operation* op, platform::DeviceContext* origin_dev_ctx, const platform::Place& place, const std::string& execution_stream, const int stream_priority); -OpFuncType AnalyseOpFuncType(::ir::Operation* op, const platform::Place& place); +OpFuncType AnalyseOpFuncType(::pir::Operation* op, + const platform::Place& place); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc index 88037b15193d8..50623c6eb1118 100644 --- a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc @@ -18,11 +18,11 @@ #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h" -#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include 
"paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/core/infermeta_utils.h" @@ -35,19 +35,20 @@ namespace framework { LegacyKernelInstruction::LegacyKernelInstruction( size_t id, const platform::Place& place, - ir::Operation* op, + pir::Operation* op, Scope* scope, Scope* local_scope, - const std::unordered_map<::ir::Value, std::string>& value_2_var_name, + const std::unordered_map& value_2_var_name, const std::map& var_name_2_id, const std::unordered_map& variable_2_var_name) : InstructionBase(id, place) { auto& op_attributes = op->attributes(); auto op_name = - op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); - ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name); - + op_attributes.at("op_name").dyn_cast().AsString(); + pir::OpInfo op_info = + pir::IrContext::Instance()->GetRegisteredOpInfo(op_name); + op_ = op; legacy_op_name_ = op_name; VLOG(6) << "construct phi kernel instruction for: " << legacy_op_name_; @@ -55,17 +56,17 @@ LegacyKernelInstruction::LegacyKernelInstruction( // if (op_attributes.count("dist_attr") != 0) { // if (op_attributes.count("execution_stream") != 0) { // SetExecutionStream(op_attributes.at("execution_stream") - // .dyn_cast<::ir::StrAttribute>() + // .dyn_cast() // .data()); // } // if (op_attributes.count("stream_priority") != 0) { // SetStreamPriority(op_attributes.at("stream_priority") - // .dyn_cast<::ir::Int32Attribute>() + // .dyn_cast() // .data()); // } // if (op_attributes.count("scheduling_priority") != 0) { // SetSchedulingPriority(op_attributes.at("scheduling_priority") - // .dyn_cast<::ir::Int64Attribute>() + // .dyn_cast() // .data()); // } // } else { @@ -98,7 +99,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( VLOG(6) << "finish process yaml_info_parser"; if (infer_meta_interface_) { - ::ir::BuildPhiContext< + pir::BuildPhiContext< phi::InferMetaContext, phi::MetaTensor, phi::MetaTensor, @@ -114,7 +115,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( VLOG(6) << "finish process infer meta context"; auto kernel_name = - op_attributes.at("kernel_name").dyn_cast().AsString(); + op_attributes.at("kernel_name").dyn_cast().AsString(); auto kernel_key = op_attributes.at("kernel_key") .dyn_cast() .data(); @@ -127,7 +128,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( Scope* inner_scope = local_scope == nullptr ? 
scope : local_scope; - operator_base_ = ir::BuildOperatorBase( + operator_base_ = pir::BuildOperatorBase( op, value_2_var_name, yaml_info_parser, variable_2_var_name, inner_scope); paddle::framework::VariableValueMap in_map; paddle::framework::VariableValueMap out_map; @@ -136,12 +137,12 @@ LegacyKernelInstruction::LegacyKernelInstruction( runtime_context_ = std::make_shared( paddle::framework::RuntimeContext(in_map, out_map)); - ir::BuildRuntimeContext(op, - value_2_var_name, - scope, - local_scope, - yaml_info_parser, - runtime_context_.get()); + pir::BuildRuntimeContext(op, + value_2_var_name, + scope, + local_scope, + yaml_info_parser, + runtime_context_.get()); kernel_context_ = new paddle::framework::ExecutionContext( *operator_base_, *local_scope, *dev_ctx, *(runtime_context_.get())); @@ -160,7 +161,7 @@ LegacyKernelInstruction::LegacyKernelInstruction( VLOG(6) << "finish process inputs outputs index"; auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds(); - std::unordered_set<::ir::Value> no_need_buffer_values; + std::unordered_set no_need_buffer_values; for (size_t id = 0; id < no_need_buffer_ids.size(); id++) { no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id])); } @@ -186,6 +187,5 @@ void LegacyKernelInstruction::Run() { (*(phi_kernel_))((kernel_context_)); VLOG(6) << "Run op " << legacy_op_name_ << " kernel."; } - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h index 27c1cb133bec0..9c6fbd9b7d807 100644 --- a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h @@ -16,10 +16,10 @@ #include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" -namespace ir { +namespace pir { class Operation; class Value; -} // namespace ir +} // namespace pir namespace paddle { namespace framework { @@ -30,10 +30,10 @@ class LegacyKernelInstruction : public InstructionBase { LegacyKernelInstruction( size_t id, const platform::Place& place, - ::ir::Operation* op, + ::pir::Operation* op, Scope* scope, Scope* local_scope, - const std::unordered_map<::ir::Value, std::string>& value_2_var_name, + const std::unordered_map<::pir::Value, std::string>& value_2_var_name, const std::map& var_name_2_id, const std::unordered_map& variable_2_var_name); @@ -53,6 +53,8 @@ class LegacyKernelInstruction : public InstructionBase { const std::string& Name() const override { return legacy_op_name_; } + ::pir::Operation* Operation() const override { return op_; } + private: std::string legacy_op_name_; @@ -66,6 +68,8 @@ class LegacyKernelInstruction : public InstructionBase { std::shared_ptr operator_base_; phi::Kernel* phi_kernel_{nullptr}; // not owned + + ::pir::Operation* op_{nullptr}; // not owned }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc index 093435f8b98a2..849a83fcf2ce9 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc @@ -17,20 +17,20 @@ #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" #include "paddle/fluid/framework/scope.h" 
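Both kernel-instruction constructors end by collecting the no-need-buffer operands: indices reported by the YAML info parser whose tensor contents are never read by the kernel (only shape/metadata), so the interpreter may release that memory early. A sketch of the collection step, with plain ints standing in for pir::Value:

#include <cstddef>
#include <unordered_set>
#include <vector>

std::unordered_set<int> CollectNoNeedBuffer(
    const std::vector<int>& operand_sources,          // op->operand_source(i)
    const std::vector<size_t>& no_need_buffer_ids) {  // from the yaml parser
  std::unordered_set<int> values;
  for (size_t id : no_need_buffer_ids) {
    values.insert(operand_sources.at(id));  // set dedupes repeated operands
  }
  return values;
}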
-#include "paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h" -#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/type_defs.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" #include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" namespace paddle { @@ -39,19 +39,20 @@ namespace framework { PhiKernelInstruction::PhiKernelInstruction( size_t id, const platform::Place& place, - ir::Operation* op, + pir::Operation* op, Scope* scope, Scope* local_scope, - const std::unordered_map<::ir::Value, std::string>& value_2_var_name, + const std::unordered_map& value_2_var_name, const std::map& var_name_2_id, const std::unordered_map& variable_2_var_name) : InstructionBase(id, place) { auto op_attributes = op->attributes(); auto op_name = - op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); - ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name); - + op_attributes.at("op_name").dyn_cast().AsString(); + pir::OpInfo op_info = + pir::IrContext::Instance()->GetRegisteredOpInfo(op_name); + op_ = op; phi_op_name_ = op_name; VLOG(6) << "construct phi kernel instruction for: " << phi_op_name_; @@ -59,17 +60,17 @@ PhiKernelInstruction::PhiKernelInstruction( // if (op_attributes.count("dist_attr") != 0) { // if (op_attributes.count("execution_stream") != 0) { // SetExecutionStream(op_attributes.at("execution_stream") - // .dyn_cast<::ir::StrAttribute>() + // .dyn_cast() // .data()); // } // if (op_attributes.count("stream_priority") != 0) { // SetStreamPriority(op_attributes.at("stream_priority") - // .dyn_cast<::ir::Int32Attribute>() + // .dyn_cast() // .data()); // } // if (op_attributes.count("scheduling_priority") != 0) { // SetSchedulingPriority(op_attributes.at("scheduling_priority") - // .dyn_cast<::ir::Int64Attribute>() + // .dyn_cast() // .data()); // } // } else { @@ -102,7 +103,7 @@ PhiKernelInstruction::PhiKernelInstruction( VLOG(6) << "finish process yaml_info_parser"; if (infer_meta_interface_) { - ::ir::BuildPhiContext< + pir::BuildPhiContext< phi::InferMetaContext, phi::MetaTensor, phi::MetaTensor, @@ -118,7 +119,7 @@ PhiKernelInstruction::PhiKernelInstruction( VLOG(6) << "finish process infer meta context"; auto kernel_name = - op_attributes.at("kernel_name").dyn_cast().AsString(); + op_attributes.at("kernel_name").dyn_cast().AsString(); auto kernel_key = op_attributes.at("kernel_key") .dyn_cast() .data(); @@ -129,17 +130,17 @@ PhiKernelInstruction::PhiKernelInstruction( phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name); VLOG(6) << 
"finish process select kernel"; - ::ir::BuildPhiContext, - paddle::small_vector, - true>(op, - value_2_var_name, - scope, - local_scope, - yaml_info_parser, - &kernel_context_); + pir::BuildPhiContext, + paddle::small_vector, + true>(op, + value_2_var_name, + scope, + local_scope, + yaml_info_parser, + &kernel_context_); kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get( phi::TransToPhiPlace(kernel_key.backend()))); VLOG(6) << "finish process kernel context"; @@ -159,7 +160,7 @@ PhiKernelInstruction::PhiKernelInstruction( VLOG(6) << "finish process inputs outputs index"; auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds(); - std::unordered_set<::ir::Value> no_need_buffer_values; + std::unordered_set no_need_buffer_values; for (size_t id = 0; id < no_need_buffer_ids.size(); id++) { no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id])); } @@ -167,6 +168,12 @@ PhiKernelInstruction::PhiKernelInstruction( VLOG(6) << "finish process no need buffer"; } +PhiKernelInstruction::~PhiKernelInstruction() { + if (phi_kernel_ != nullptr) { + delete phi_kernel_; + } +} + void PhiKernelInstruction::Run() { if (infer_meta_interface_) { infer_meta_interface_->infer_meta_(&(infer_meta_context_)); diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h index c637cce8651fb..96484f435a9f7 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h @@ -16,9 +16,9 @@ #include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" -namespace ir { +namespace pir { class Operation; -} // namespace ir +} // namespace pir namespace paddle { namespace framework { @@ -30,14 +30,16 @@ class PhiKernelInstruction : public InstructionBase { PhiKernelInstruction( size_t id, const platform::Place& place, - ::ir::Operation* op, + ::pir::Operation* op, Scope* scope, Scope* local_scope, - const std::unordered_map<::ir::Value, std::string>& value_2_var_name, + const std::unordered_map<::pir::Value, std::string>& value_2_var_name, const std::map& var_name_2_id, const std::unordered_map& variable_2_var_name); + ~PhiKernelInstruction(); + phi::Kernel* PhiKernel() const { return phi_kernel_; } const phi::KernelContext& KernelContext() const { return kernel_context_; } @@ -50,6 +52,8 @@ class PhiKernelInstruction : public InstructionBase { return infer_meta_interface_; } + ::pir::Operation* Operation() const override { return op_; } + void Run() override; const std::string& Name() const override { return phi_op_name_; } @@ -65,6 +69,8 @@ class PhiKernelInstruction : public InstructionBase { phi::Kernel* phi_kernel_{nullptr}; // not owned std::string phi_op_name_; + + ::pir::Operation* op_{nullptr}; // not owned }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index a717a3ed09531..c6655c55fb2c3 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -23,15 +23,16 @@ #include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h" #include "paddle/fluid/framework/new_executor/interpreter/execution_config.h" #include "paddle/fluid/framework/new_executor/interpreter/static_build.h" -#include 
"paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h" -#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" #include "paddle/fluid/memory/stats.h" #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" +#include "paddle/fluid/operators/controlflow/pylayer_op_helper.h" #include "paddle/fluid/operators/controlflow/recurrent_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/ops_extra_info.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h" #include "paddle/fluid/platform/flags.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/kernel_context.h" @@ -191,7 +192,7 @@ bool IsMemcpyH2D(Instruction* instr) { } bool IsMemcpyH2D(paddle::framework::InstructionBase* instr) { - return instr->Name() == "pd.memcpy_h2d"; + return instr->Name() == "pd_op.memcpy_h2d"; } bool IsMemcpyOp(const Instruction& instr) { @@ -571,6 +572,8 @@ void BuildOpFuncList(const platform::Place& place, const ProgramDesc& main_program = *block.Program(); operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp( main_program, block.ID(), ops_unique); + operators::PrepareSafeEagerDeletionOnPyLayerOpAndPyLayerGradOp( + main_program, block.ID(), ops_unique); operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( main_program, block.ID(), ops_unique); operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp( @@ -611,6 +614,8 @@ void BuildOpFuncList(const platform::Place& place, const std::set ops_with_var_not_in_scope = { "conditional_block", "conditional_block_grad", + "pylayer", + "pylayer_grad" "recurrent_grad", "rnn_memory_helper", "rnn_memory_helper_grad", @@ -1016,23 +1021,23 @@ void BuildOpFuncList(const platform::Place& place, void BuildOpFuncList( const platform::Place& place, - ::ir::Block* block, + pir::Block* block, std::vector* vec_func_list, framework::Scope* scope, framework::Scope* local_scope, - const std::unordered_map<::ir::Value, std::string>& value_2_name_map, + const std::unordered_map& value_2_name_map, const ExecutionConfig& execution_config) { vec_func_list->reserve(block->size()); - ::ir::IrContext* ctx = ir::IrContext::Instance(); + pir::IrContext* ctx = pir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); for (auto op : *block) { OpFuncNode op_func_node; auto attr_map = op->attributes(); auto op_name = - attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); + attr_map.at("op_name").dyn_cast().AsString(); op_func_node.phi_op_name_ = op_name; if (GetSpecialOpNames().count(op_name)) { @@ -1040,7 +1045,7 @@ void BuildOpFuncList( continue; } - ::ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); auto impl = op_info.GetInterfaceImpl(); @@ -1051,7 +1056,7 @@ void BuildOpFuncList( VLOG(6) << "op name" << op_func_node.phi_op_name_; dialect::OpYamlInfoParser op_yaml_info_parser(impl->get_op_info_()); if (op_func_node.infer_meta_interface_) { - ::ir::BuildPhiContext< + pir::BuildPhiContext< phi::InferMetaContext, phi::MetaTensor, phi::MetaTensor, @@ -1066,7 +1071,7 @@ void 
BuildOpFuncList( } auto kernel_name = - attr_map.at("kernel_name").dyn_cast().AsString(); + attr_map.at("kernel_name").dyn_cast().AsString(); auto kernel_key = attr_map.at("kernel_key") .dyn_cast() .data(); @@ -1081,17 +1086,17 @@ void BuildOpFuncList( "not found kernel for [%s]", kernel_name); - ::ir::BuildPhiContext, - paddle::small_vector, - true>(op, - value_2_name_map, - scope, - local_scope, - op_yaml_info_parser, - &(op_func_node.kernel_context_)); + pir::BuildPhiContext, + paddle::small_vector, + true>(op, + value_2_name_map, + scope, + local_scope, + op_yaml_info_parser, + &(op_func_node.kernel_context_)); VLOG(6) << "finish process kernel context"; op_func_node.kernel_context_.SetDeviceContext( @@ -1184,12 +1189,12 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base, } } -void SetDeviceCommContext(::ir::Operation* op, +void SetDeviceCommContext(pir::Operation* op, platform::DeviceContext* dev_ctx) { auto op_attributes = op->attributes(); if (op_attributes.count("ring_id") != 0) { int ring_id = - op_attributes.at("ring_id").dyn_cast<::ir::Int32Attribute>().data(); + op_attributes.at("ring_id").dyn_cast().data(); const auto& comm_context_manager = phi::distributed::CommContextManager::GetInstance(); if (comm_context_manager.Has(std::to_string(ring_id))) { @@ -1200,7 +1205,7 @@ void SetDeviceCommContext(::ir::Operation* op, } else { VLOG(3) << "op: " << op_attributes.at("op_name") - .dyn_cast<::ir::StrAttribute>() + .dyn_cast() .AsString() << ", ring_id: " << ring_id << ", get comm_context failed!"; } @@ -1211,11 +1216,11 @@ std::unordered_set GetSpecialOpNames() { return { "builtin.combine", "builtin.slice", - "pd.feed", + "pd_op.feed", "builtin.set_parameter", "builtin.get_parameter", - "pd.data", - "pd.shadow_output", + "pd_op.data", + "pd_op.shadow_output", }; } @@ -1229,6 +1234,32 @@ void BuildId2VarName(const std::map& var_name_2_id, } } +const std::vector GetInstructionCallStack( + const std::string& type, const pir::AttributeMap& attrs) { + std::vector vec_str; + if (attrs.count("sub_block") != 0) { + return vec_str; + } + auto iter = attrs.find(OpProtoAndCheckerMaker::OpCreationCallstackAttrName()); + if (iter != attrs.end()) { + auto attr = iter->second; + PADDLE_ENFORCE( + attr.isa(), + paddle::platform::errors::InvalidArgument( + "%s: Callstack attributes of %s is not ArrayAttribute type", type)); + pir::ArrayAttribute array_attribute = attr.dyn_cast(); + std::vector vec_attr = array_attribute.AsVector(); + for (auto value : vec_attr) { + PADDLE_ENFORCE( + value.isa(), + paddle::platform::errors::InvalidArgument( + "%s: Callstack attributes of %s is not StrAttribute type", type)); + vec_str.emplace_back(value.dyn_cast().AsString()); + } + } + return vec_str; +} + } // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h index 33b89cac542d4..413db7e75ecd4 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h @@ -106,11 +106,11 @@ void BuildOpFuncList(const platform::Place& place, void BuildOpFuncList( const platform::Place& place, - ::ir::Block* block, + ::pir::Block* block, std::vector* vec_func_list, framework::Scope* scope, framework::Scope* local_scope, - const std::unordered_map<::ir::Value, std::string>& value_2_name_map, + const std::unordered_map<::pir::Value, std::string>& 
value_2_name_map, const ExecutionConfig& execution_config); void BuildVariableScope(const framework::BlockDesc& block, @@ -124,10 +124,13 @@ void LogDeviceMemoryStats(const platform::Place& place); void SetDeviceCommContext(framework::OperatorBase* operator_base, platform::DeviceContext* dev_ctx); -void SetDeviceCommContext(::ir::Operation* op, +void SetDeviceCommContext(::pir::Operation* op, platform::DeviceContext* dev_ctx); std::unordered_set GetSpecialOpNames(); + +const std::vector GetInstructionCallStack( + const std::string& type, const pir::AttributeMap& attrs); } // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpreter/plan.cc b/paddle/fluid/framework/new_executor/interpreter/plan.cc index 0217219302f6d..ce2f8b2718ff3 100644 --- a/paddle/fluid/framework/new_executor/interpreter/plan.cc +++ b/paddle/fluid/framework/new_executor/interpreter/plan.cc @@ -41,7 +41,7 @@ Plan::Plan(const std::vector>& job_list, Plan::Plan( const std::vector>& job_list, - const std::unordered_map>& + const std::unordered_map>& type_to_ir_program) : job_list_(job_list), type_to_ir_program_(type_to_ir_program), @@ -69,7 +69,7 @@ const ProgramDesc* Plan::Program(const std::string& job_type) const { return type_to_program_.at(job_type); } -std::shared_ptr<::ir::Program> Plan::IrProgram( +std::shared_ptr<::pir::Program> Plan::IrProgram( const std::string& job_type) const { return type_to_ir_program_.at(job_type); } diff --git a/paddle/fluid/framework/new_executor/interpreter/plan.h b/paddle/fluid/framework/new_executor/interpreter/plan.h index aac750a38f97b..8ce66db821305 100644 --- a/paddle/fluid/framework/new_executor/interpreter/plan.h +++ b/paddle/fluid/framework/new_executor/interpreter/plan.h @@ -21,8 +21,8 @@ #include "paddle/fluid/framework/new_executor/interpreter/job.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/ir/core/program.h" #include "paddle/phi/core/macros.h" +#include "paddle/pir/core/program.h" namespace paddle { namespace framework { @@ -33,7 +33,7 @@ class Plan final { Plan(const std::vector>& job_list, const std::unordered_map& type_to_program); Plan(const std::vector>& job_list, - const std::unordered_map>& + const std::unordered_map>& type_to_ir_program); ~Plan() = default; @@ -41,14 +41,14 @@ class Plan final { const std::vector>& JobList() const; const ProgramDesc* Program(const std::string& job_type) const; - std::shared_ptr<::ir::Program> IrProgram(const std::string& job_type) const; + std::shared_ptr<::pir::Program> IrProgram(const std::string& job_type) const; int64_t MicroBatchNum() const; private: const std::vector> job_list_; const std::unordered_map type_to_program_; - const std::unordered_map> + const std::unordered_map> type_to_ir_program_; int64_t micro_batch_num_; }; diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index 3dc9175dbfd4b..bbbaf4c0dd75f 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -19,8 +19,14 @@ #include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" -#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device_context.h" +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include 
"paddle/fluid/platform/collective_helper.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); +#endif namespace paddle { namespace framework { @@ -235,9 +241,20 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext( if (op_type == "c_allreduce_sum" && op->Attr("use_calc_stream") == false) { int ring_id = op->Attr("ring_id"); - return platform::NCCLCommContext::Instance() - .Get(ring_id, place_) - ->dev_context(); + + if (FLAGS_dynamic_static_unified_comm) { + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + dev_ctx = static_cast( + static_cast( + comm_context_manager.Get(std::to_string(ring_id))) + ->GetDevContext()); + } else { + dev_ctx = platform::NCCLCommContext::Instance() + .Get(ring_id, place_) + ->dev_context(); + } + return dev_ctx; } #endif } @@ -257,7 +274,7 @@ const std::unordered_set no_need_buffer_ins(Instruction* instr) { return std::unordered_set(); } -const std::unordered_set no_need_buffer_ins( +const std::unordered_set no_need_buffer_ins( const paddle::framework::InstructionBase* instr) { return instr->NoNeedBuffer(); } @@ -471,9 +488,9 @@ void analyse_event_info_for_two_instructions< // fused_var share the same tensor. However, as the dependency is implicit, we // can only add event for it with the help of depend_op. - if (has_data_dependency( + if (has_data_dependency( instructions[cur_instr_id], instructions[next_instr_id]) || - instructions[next_instr_id]->Name() == "pd.depend") { + instructions[next_instr_id]->Name() == "pd_op.depend") { waiter_instr_ids->insert(next_instr_id); return; } diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 384c668ed2e56..a2c3c49e1c634 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -16,8 +16,8 @@ #include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" #include "paddle/fluid/framework/new_executor/program_interpreter.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/value.h" PADDLE_DEFINE_EXPORTED_bool( new_executor_serial_run, @@ -50,7 +50,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place, InterpreterCore::InterpreterCore( const platform::Place& place, const std::vector& fetch_var_names, - std::unique_ptr<::ir::Program> ir_prog, + std::unique_ptr<::pir::Program> ir_prog, framework::Scope* scope, const ExecutionConfig& execution_config) { VLOG(4) << "InterpreterCore(): " << this << " on " << place; diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index f01c12b27c3a1..52df30cbfd976 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -17,9 +17,9 @@ PD_DECLARE_bool(new_executor_use_local_scope); -namespace ir { +namespace pir { class Program; -} // namespace ir +} // namespace pir namespace paddle { namespace framework { @@ -38,7 +38,7 @@ class InterpreterCore { // This constructor is for New IR. 
InterpreterCore(const platform::Place& place, const std::vector& fetch_var_names, - std::unique_ptr<::ir::Program> ir_prog, + std::unique_ptr<::pir::Program> ir_prog, Scope* scope, const ExecutionConfig& execution_config = ExecutionConfig()); ~InterpreterCore(); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index bf0c0880f385d..ee9f17034a45f 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -20,7 +20,7 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/variable_helper.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/platform/device_event_base.h" #include "paddle/fluid/platform/event.h" #include "paddle/phi/core/utils/rw_lock.h" diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc index 94ef1e3af217e..78225dee6f337 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc @@ -41,16 +41,15 @@ #endif #include "paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h" -#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" -#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h" +#include "paddle/pir/core/builtin_attribute.h" PHI_DECLARE_bool(enable_new_ir_in_executor); - PHI_DECLARE_bool(enable_new_ir_in_executor_trace_run); namespace paddle { @@ -59,7 +58,7 @@ namespace framework { NewIRInterpreter::NewIRInterpreter( const platform::Place& place, const std::vector& fetch_var_names, - std::unique_ptr<::ir::Program> ir_prog, + std::unique_ptr<::pir::Program> ir_prog, framework::Scope* scope, const ExecutionConfig& execution_config) : place_(place), @@ -349,79 +348,79 @@ void NewIRInterpreter::UpdateSyncOpNum() { void NewIRInterpreter::UpdateNcclOpNum() { static std::set nccl_op_set = { - "pd.c_softmax_with_cross_entropy", - "pd.c_allgather", - "pd.c_allreduce_max", - "pd.c_allreduce_min", - "pd.c_allreduce_sum", - "pd.c_allreduce_prod", - "pd.c_reduce_max", - "pd.c_reduce_min", - "pd.c_reduce_prod", - "pd.c_reducescatter", - "pd.c_broadcast", - "pd.c_broadcast_", - "pd.c_scatter", - "pd.partial_send", - "pd.partial_recv", - "pd.partial_allgather", - "pd.recv_v2", - "pd.send_v2", - "pd.mp_allreduce_sum", - "pd.barrier", - "pd.alltoall", - "pd.global_gather", - "pd.distributed_fused_lamb", - "pd.margin_cross_entropy", - "pd.sync_batch_norm", - "pd.sync_batch_norm_", - "pd.data_norm", - 
"pd.class_center_sample", - "pd.all_to_all", - "pd.dist_concat", - "pd.all_gather", - "pd.broadcast", - "pd.p_recv", - "pd.p_send", - "pd.reduce_scatter", - "pd.all_reduce", - "pd.reduce", - "pd.c_softmax_with_cross_entropy_grad", - "pd.c_allgather_grad", - "pd.c_allreduce_max_grad", - "pd.c_allreduce_min_grad", - "pd.c_allreduce_sum_grad", - "pd.c_allreduce_prod_grad", - "pd.c_reduce_max_grad", - "pd.c_reduce_min_grad", - "pd.c_reduce_prod_grad", - "pd.c_reducescatter_grad", - "pd.c_broadcast_grad", - "pd.c_scatter_grad", - "pd.partial_send_grad", - "pd.partial_recv_grad", - "pd.partial_allgather_grad", - "pd.recv_v2_grad", - "pd.send_v2_grad", - "pd.mp_allreduce_sum_grad", - "pd.barrier_grad", - "pd.alltoall_grad", - "pd.global_gather_grad", - "pd.distributed_fused_lamb_grad", - "pd.margin_cross_entropy_grad", - "pd.margin_cross_entropy_grad_" - "pd.sync_batch_norm_grad", - "pd.data_norm_grad", - "pd.class_center_sample_grad", - "pd.all_to_all_grad", - "pd.dist_concat_grad", - "pd.all_gather_grad", - "pd.broadcast_grad", - "pd.p_recv_grad", - "pd.p_send_grad", - "pd.reduce_scatter_grad", - "pd.all_reduce_grad", - "pd.reduce_grad"}; + "pd_op.c_softmax_with_cross_entropy", + "pd_op.c_allgather", + "pd_op.c_allreduce_max", + "pd_op.c_allreduce_min", + "pd_op.c_allreduce_sum", + "pd_op.c_allreduce_prod", + "pd_op.c_reduce_max", + "pd_op.c_reduce_min", + "pd_op.c_reduce_prod", + "pd_op.c_reducescatter", + "pd_op.c_broadcast", + "pd_op.c_broadcast_", + "pd_op.c_scatter", + "pd_op.partial_send", + "pd_op.partial_recv", + "pd_op.partial_allgather", + "pd_op.recv_v2", + "pd_op.send_v2", + "pd_op.mp_allreduce_sum", + "pd_op.barrier", + "pd_op.alltoall", + "pd_op.global_gather", + "pd_op.distributed_fused_lamb", + "pd_op.margin_cross_entropy", + "pd_op.sync_batch_norm", + "pd_op.sync_batch_norm_", + "pd_op.data_norm", + "pd_op.class_center_sample", + "pd_op.all_to_all", + "pd_op.dist_concat", + "pd_op.all_gather", + "pd_op.broadcast", + "pd_op.p_recv", + "pd_op.p_send", + "pd_op.reduce_scatter", + "pd_op.all_reduce", + "pd_op.reduce", + "pd_op.c_softmax_with_cross_entropy_grad", + "pd_op.c_allgather_grad", + "pd_op.c_allreduce_max_grad", + "pd_op.c_allreduce_min_grad", + "pd_op.c_allreduce_sum_grad", + "pd_op.c_allreduce_prod_grad", + "pd_op.c_reduce_max_grad", + "pd_op.c_reduce_min_grad", + "pd_op.c_reduce_prod_grad", + "pd_op.c_reducescatter_grad", + "pd_op.c_broadcast_grad", + "pd_op.c_scatter_grad", + "pd_op.partial_send_grad", + "pd_op.partial_recv_grad", + "pd_op.partial_allgather_grad", + "pd_op.recv_v2_grad", + "pd_op.send_v2_grad", + "pd_op.mp_allreduce_sum_grad", + "pd_op.barrier_grad", + "pd_op.alltoall_grad", + "pd_op.global_gather_grad", + "pd_op.distributed_fused_lamb_grad", + "pd_op.margin_cross_entropy_grad", + "pd_op.margin_cross_entropy_grad_" + "pd_op.sync_batch_norm_grad", + "pd_op.data_norm_grad", + "pd_op.class_center_sample_grad", + "pd_op.all_to_all_grad", + "pd_op.dist_concat_grad", + "pd_op.all_gather_grad", + "pd_op.broadcast_grad", + "pd_op.p_recv_grad", + "pd_op.p_send_grad", + "pd_op.reduce_scatter_grad", + "pd_op.all_reduce_grad", + "pd_op.reduce_grad"}; int64_t nccl_op_num = 0; for (auto& ins : vec_instruction_base_) { if (nccl_op_set.count(ins->Name())) { @@ -512,7 +511,7 @@ void NewIRInterpreter::BuildInstruction() { } else if (op->dialect()->name() == "pd_kernel") { auto op_name = op->attributes() .at("op_name") - .dyn_cast<::ir::StrAttribute>() + .dyn_cast<::pir::StrAttribute>() .AsString(); if (interpreter::GetSpecialOpNames().count(op_name)) { VLOG(6) << "skip 
process " << op_name; @@ -542,7 +541,7 @@ void NewIRInterpreter::BuildInstruction() { variable_2_var_name_)); } #ifdef PADDLE_WITH_CINN - } else if (op->dialect()->name() == "cinn") { + } else if (op->dialect()->name() == "cinn_runtime") { vec_instruction_base_.emplace_back( std::make_unique(op_idx++, place_, op, scope_)); #endif @@ -634,7 +633,7 @@ void NewIRInterpreter::BuildInstructionDependences() { void NewIRInterpreter::RecordMemcpyD2H(InstructionBase* instr_node) { // NOTE(zhiqiu): hot fix for jit input var - if (instr_node->Name() == "pd.memcpy_d2h") { + if (instr_node->Name() == "pd_op.memcpy_d2h") { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* default_dev_ctx = pool.Get(place_); for (auto& event : instr_node->EventsToWait()) { @@ -781,14 +780,14 @@ void NewIRInterpreter::CalculateLastLiveOps() { InstructionBase* instr = vec_instruction_base_[op_idx].get(); std::set gc_check_vars; - const std::unordered_map<::ir::Value, std::vector>& ins = + const std::unordered_map<::pir::Value, std::vector>& ins = instr->Inputs(); - const std::unordered_map<::ir::Value, std::vector>& outs = + const std::unordered_map<::pir::Value, std::vector>& outs = instr->Outputs(); - std::unordered_multimap<::ir::Value, std::vector> ins_and_outs{ + std::unordered_multimap<::pir::Value, std::vector> ins_and_outs{ ins.begin(), ins.end()}; - if (instr->Name() != "pd.fetch") { + if (instr->Name() != "pd_op.fetch") { ins_and_outs.insert(outs.begin(), outs.end()); } @@ -879,7 +878,8 @@ void NewIRInterpreter::ConstructEventForJitInput() { for (size_t i = 0; i < dependecy_count_->size(); ++i) { if ((*dependecy_count_)[i] == 0) { InstructionBase* inst = vec_instruction_base_[i].get(); - if (inst->Name() == "pd.memcpy_d2h" && platform::is_gpu_place(place_)) { + if (inst->Name() == "pd_op.memcpy_d2h" && + platform::is_gpu_place(place_)) { for (auto& item : inst->Inputs()) { for (auto var_id : item.second) { auto name = GetNameById(var_id); @@ -919,13 +919,13 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, // Build std::stringstream ss; ss << this; - ::ir::BuildScope(*ir_program_->block(), - InnerScope(), - ss.str(), - &value_2_var_name_, - &variable_2_var_name_, - &var_name_2_id_, - &variable_list_); + ::pir::BuildScope(*ir_program_->block(), + InnerScope(), + ss.str(), + &value_2_var_name_, + &variable_2_var_name_, + &var_name_2_id_, + &variable_list_); interpreter::BuildId2VarName(var_name_2_id_, &id_2_var_name_); @@ -1240,6 +1240,10 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { VLOG(5) << "after run kernel"; instr_node->RecordEvent(place_); } catch (platform::EnforceNotMet& ex) { + auto* op = instr_node->Operation(); + const std::vector op_callstack_attr = + interpreter::GetInstructionCallStack(op->name(), op->attributes()); + framework::InsertCallStackInfo(op->name(), op_callstack_attr, &ex); LOG(WARNING) << instr_node->Name() << " raises an EnforceNotMet exception " << platform::demangle(typeid(ex).name()) << ", " << ex.what(); exception_holder_.Catch(std::make_exception_ptr(std::move(ex))); @@ -1281,7 +1285,7 @@ void NewIRInterpreter::PreAnalysis() { VLOG(4) << "Done UpdateNcclOpNum"; } -::ir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) { +::pir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) { for (auto kv : value_2_var_name_) { if (kv.second == var_name) { return kv.first; @@ -1293,16 +1297,16 @@ ::ir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) { void 
NewIRInterpreter::SolvePersisableVarNames() { VLOG(6) << "SolvePersisableVarNames"; for (auto kv : value_2_var_name_) { - ::ir::Value value = kv.first; + ::pir::Value value = kv.first; const std::string& var_name = kv.second; - ::ir::OpResult result = value.dyn_cast<::ir::OpResult>(); + ::pir::OpResult result = value.dyn_cast<::pir::OpResult>(); auto* defining_op = value.GetDefiningOp(); if (defining_op->HasAttribute(kAttrIsPersisable)) { auto is_persisables = defining_op->attribute(kAttrIsPersisable) - .dyn_cast<::ir::ArrayAttribute>() + .dyn_cast<::pir::ArrayAttribute>() .AsVector(); if (is_persisables[result.GetResultIndex()] - .dyn_cast<::ir::BoolAttribute>() + .dyn_cast<::pir::BoolAttribute>() .data()) { VLOG(6) << "parameter_var_names_ include: " << var_name; parameter_var_names_.insert(var_name); diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/new_ir_interpreter.h index b37b26d107560..c0681a277d5f7 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.h @@ -16,7 +16,7 @@ #include #include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" #include "paddle/fluid/framework/new_executor/interpreter_base_impl.h" -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/value.h" namespace ir { class Program; @@ -36,7 +36,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { public: NewIRInterpreter(const platform::Place& place, const std::vector& fetch_var_names, - std::unique_ptr<::ir::Program> ir_prog, + std::unique_ptr<::pir::Program> ir_prog, Scope* scope, const ExecutionConfig& execution_config = ExecutionConfig()); @@ -184,7 +184,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { void RecordMemcpyD2H(InstructionBase* instr_node); - ::ir::Value GetValueByName(const std::string& var_name); + ::pir::Value GetValueByName(const std::string& var_name); void CheckGC(InstructionBase* instr); @@ -198,11 +198,11 @@ class NewIRInterpreter : public InterpreterBaseImpl { InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less; - std::unique_ptr<::ir::Program> ir_program_{nullptr}; + std::unique_ptr<::pir::Program> ir_program_{nullptr}; std::vector> vec_instruction_base_; - std::unordered_map<::ir::Value, std::string> value_2_var_name_; + std::unordered_map<::pir::Value, std::string> value_2_var_name_; std::unordered_map variable_2_var_name_; diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index f59d5812273c3..a29e45515d894 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -32,6 +32,10 @@ #include "paddle/phi/backends/device_manager.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); #endif namespace paddle { @@ -1204,10 +1208,18 @@ void ProgramInterpreter::RecordStreamForGC(const Instruction& instr) { auto operator_base_ptr = instr.OpBase(); if ((operator_base_ptr->Type() == "send_v2") && (operator_base_ptr->Attr("use_calc_stream") == false)) { - stream = platform::NCCLCommContext::Instance() - .Get(operator_base_ptr->Attr("ring_id"), - 
instr.DeviceContext().GetPlace()) - ->stream(); + int ring_id = operator_base_ptr->Attr("ring_id"); + if (FLAGS_dynamic_static_unified_comm) { + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + stream = static_cast( + comm_context_manager.Get(std::to_string(ring_id))) + ->GetStream(); + } else { + stream = platform::NCCLCommContext::Instance() + .Get(ring_id, instr.DeviceContext().GetPlace()) + ->stream(); + } } #endif auto TensorRecordStream = [&stream](phi::DenseTensor& tensor) { diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index ed109f9cd0b96..a2ae422b814a3 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -19,13 +19,13 @@ #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/core/flags.h" -#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" -#include "paddle/fluid/ir/transforms/inplace_pass.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/pass/pass.h" -#include "paddle/ir/pass/pass_manager.h" +#include "paddle/fluid/pir/transforms/inplace_pass.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" PHI_DECLARE_bool(enable_new_ir_in_executor); PHI_DECLARE_bool(enable_new_ir_api); @@ -54,7 +54,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, for (const auto& job : jobs) { const std::string& job_type = job->Type(); std::shared_ptr program = nullptr; - std::shared_ptr<::ir::Program> ir_program = nullptr; + std::shared_ptr<::pir::Program> ir_program = nullptr; if (FLAGS_enable_new_ir_api) { ir_program = plan_.IrProgram(job_type); } else { @@ -79,18 +79,18 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, // TODO(phlrain) we only support cpu for now if (FLAGS_enable_new_ir_in_executor) { - std::shared_ptr<::ir::Program> base_program = ir_program; + std::shared_ptr<::pir::Program> base_program = ir_program; if (!FLAGS_enable_new_ir_api) { VLOG(6) << "begin to translate" << std::endl; base_program = paddle::TranslateLegacyProgramToProgram(*program); } auto block = base_program->block(); for (auto it = block->begin(); it != block->end(); ++it) { - if ((*it)->name() == "pd.fetch") { + if ((*it)->name() == "pd_op.fetch") { size_t index = (*it) ->attributes() .at("col") - .dyn_cast() + .dyn_cast() .data(); if (fetch_var_names_.size() < index + 1) { @@ -100,7 +100,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, fetch_var_names_[index] = (*it) ->attributes() .at("name") - .dyn_cast() + .dyn_cast() .AsString() + "@fetch"; } @@ -109,8 +109,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, paddle::dialect::PdOpLowerToKernelPass(base_program.get(), place); if (FLAGS_new_ir_apply_inplace_pass) { - ir::PassManager pm(ir::IrContext::Instance(), 3); - pm.AddPass(ir::CreateInplacePass()); + pir::PassManager pm(pir::IrContext::Instance(), 3); + pm.AddPass(pir::CreateInplacePass()); pm.Run(kernel_program.get()); } diff --git a/paddle/fluid/framework/new_executor/standalone_executor.h b/paddle/fluid/framework/new_executor/standalone_executor.h index bec52add981bf..e9ee5509d20be 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.h +++ 
b/paddle/fluid/framework/new_executor/standalone_executor.h @@ -24,7 +24,7 @@ #include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/place.h" -#include "paddle/ir/core/program.h" +#include "paddle/pir/core/program.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/op_call_stack.cc b/paddle/fluid/framework/op_call_stack.cc index b9a7aad1fdf4a..f7b60af104747 100644 --- a/paddle/fluid/framework/op_call_stack.cc +++ b/paddle/fluid/framework/op_call_stack.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_call_stack.h" - #include - #include "paddle/fluid/framework/op_proto_maker.h" namespace paddle { @@ -34,7 +32,7 @@ std::string InsertIndentationIntoEachLine(const std::string &str) { } void InsertCallStackInfo(const std::string &type, - const AttributeMap &attrs, + const paddle::framework::AttributeMap &attrs, platform::EnforceNotMet *exception) { if (attrs.count("sub_block") != 0) { return; @@ -76,6 +74,39 @@ void InsertCallStackInfo(const std::string &type, exception->set_error_str(sout.str()); } +void InsertCallStackInfo(const std::string &type, + const std::vector &callstack_attr_str, + platform::EnforceNotMet *exception) { + const std::vector *callstack = &callstack_attr_str; + std::ostringstream sout; + // Step 1. Construct python call stack string + if (callstack) { + if (FLAGS_call_stack_level > 1) { + sout << "\n\n Compile Traceback (most recent call last):"; + } else { + sout << "In user code:\n"; + } + for (auto &line : *callstack) { + sout << "\n " << line; + } + } + VLOG(1) << exception->error_str(); + // Step 2. Construct final call stack & append error op name + if (FLAGS_call_stack_level > 1) { + sout << exception->what(); + } else { + // If callstack exists, use err_str_ instead sub_err_str_ + if (callstack) { + sout << "\n\n"; + sout << InsertIndentationIntoEachLine(exception->error_str()); + } else { + sout << exception->simple_error_str(); + } + } + sout << " [operator < " << type << " > error]"; + exception->set_error_str(sout.str()); +} + void AppendErrorOpHint(const std::string &type, platform::EnforceNotMet *exception) { std::ostringstream sout; diff --git a/paddle/fluid/framework/op_call_stack.h b/paddle/fluid/framework/op_call_stack.h index 0cd10df89b86c..9f9ecd14ef8be 100644 --- a/paddle/fluid/framework/op_call_stack.h +++ b/paddle/fluid/framework/op_call_stack.h @@ -24,7 +24,11 @@ namespace framework { // insert python call stack & append error op for exception message void InsertCallStackInfo(const std::string &type, - const AttributeMap &attrs, + const paddle::framework::AttributeMap &attrs, + platform::EnforceNotMet *exception); + +void InsertCallStackInfo(const std::string &type, + const std::vector &callstack_attr_str, platform::EnforceNotMet *exception); // only append error op for exception message diff --git a/paddle/fluid/framework/op_call_stack_test.cc b/paddle/fluid/framework/op_call_stack_test.cc index 23bb25270ccc8..dee60aa0fe3de 100644 --- a/paddle/fluid/framework/op_call_stack_test.cc +++ b/paddle/fluid/framework/op_call_stack_test.cc @@ -44,6 +44,7 @@ TEST(OpCallStack, InsertCallStackInfo) { stack_test_vec.emplace_back(stack_test_str); attr_map["op_callstack"] = stack_test_vec; paddle::framework::InsertCallStackInfo("test", attr_map, &exception); + paddle::framework::InsertCallStackInfo("test", stack_test_vec, &exception); 
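The second InsertCallStackInfo call above exercises the new overload, which takes the already-extracted op_callstack strings instead of an AttributeMap; this is the path the PIR interpreter's RunInstructionBase catch block feeds via interpreter::GetInstructionCallStack. A rough standalone sketch of the two-step formatting the overload performs (FormatCallStack and its parameters are illustrative stand-ins, not Paddle API):

```cpp
// Rough standalone sketch of the formatting done by the new
// InsertCallStackInfo(type, callstack_attr_str, exception) overload.
// FormatCallStack and call_stack_level are illustrative stand-ins; the
// real code also re-indents the error body via InsertIndentationIntoEachLine.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

std::string FormatCallStack(const std::string& type,
                            const std::vector<std::string>& callstack,
                            const std::string& error_str,
                            int call_stack_level) {
  std::ostringstream sout;
  // Step 1: replay the Python frames stored in the op's callstack attribute.
  if (call_stack_level > 1) {
    sout << "\n\n  Compile Traceback (most recent call last):";
  } else {
    sout << "In user code:\n";
  }
  for (const auto& line : callstack) {
    sout << "\n  " << line;
  }
  // Step 2: append the C++ error text, then name the offending operator.
  sout << "\n\n" << error_str;
  sout << "  [operator < " << type << " > error]";
  return sout.str();
}

int main() {
  std::vector<std::string> frames = {
      "File \"train.py\", line 42, in <module>", "    out = model(x)"};
  std::cout << FormatCallStack(
      "matmul", frames, "InvalidArgumentError: ...\n", /*call_stack_level=*/1);
}
```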
std::string ex_msg = exception.what(); EXPECT_TRUE(ex_msg.find(stack_test_str) != std::string::npos); EXPECT_TRUE(ex_msg.find("[operator < test > error]") != std::string::npos); diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index ab74b2691b062..a2eef6417870a 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -25,6 +25,8 @@ limitations under the License. */ #include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/operators/ops_extra_info.h" #include "paddle/phi/common/complex.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/value.h" #include "paddle/utils/blank.h" namespace paddle { @@ -964,7 +966,12 @@ struct SetAttrDescVisitor { void operator()(const std::vector &v) const { VectorToRepeated(v, attr_->mutable_bools()); } - + void operator()(const std::vector &v) const { + // just do nothing. + } + void operator()(const std::vector &v) const { + // just do nothing. + } void operator()(const std::vector &v) const { std::vector var_names; for (auto var : v) { diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index e0ddafd37da70..ff898db3819f6 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -48,9 +48,9 @@ #include "paddle/fluid/operators/cinn/cinn_launch_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/string/string_helper.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/value.h" #include "paddle/phi/core/flags.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/value.h" #include "paddle/utils/flags.h" PHI_DECLARE_bool(enable_pe_launch_cinn); diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 961b7c1e663c0..4ad1bcb80c4bc 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -25,6 +25,8 @@ limitations under the License. */ #include "paddle/fluid/imperative/type_defs.h" #include "paddle/phi/common/scalar.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/value.h" #include "paddle/utils/blank.h" #include "paddle/utils/small_vector.h" #include "paddle/utils/variant.h" @@ -62,7 +64,9 @@ using Attribute = paddle::variant, double, paddle::experimental::Scalar, - std::vector>; + std::vector, + ::pir::Block*, + std::vector<::pir::Value>>; using AttributeMap = std::unordered_map; using OpCreator = diff --git a/paddle/fluid/framework/type_info.cc b/paddle/fluid/framework/type_info.cc index cb7dae540d119..03086f46ad216 100644 --- a/paddle/fluid/framework/type_info.cc +++ b/paddle/fluid/framework/type_info.cc @@ -16,7 +16,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/raw_tensor.h" #include "paddle/fluid/framework/string_array.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.h" +#include "paddle/fluid/pir/dialect/operator/ir/meta_tensor.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 48c9f79f34de1..da39c21e84c03 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -103,8 +103,8 @@ set(SHARED_INFERENCE_SRCS # NOTE(Aurelius84): For inference library, some DEPS is usless # such as non-infer operator related targets et.al. -list(REMOVE_ITEM fluid_modules cinn_dialect) -# NOTE(Aurelisu84): Remove ir dialect related target DEPS for inference +list(REMOVE_ITEM fluid_modules cinn_op_dialect) +# NOTE(Aurelisu84): Remove pir dialect related target DEPS for inference # shared library to prune library size. list(REMOVE_ITEM fluid_modules ${not_infer_modules}) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 83f75c1ae0703..1e3be4d0cfcd3 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2010,12 +2010,9 @@ std::unique_ptr AnalysisPredictor::GetInputTensor( } } else if (platform::is_custom_place(place_)) { auto custom_place = place_; - auto paddleplace = static_cast( - static_cast(PaddlePlace::kCUSTOM) + - phi::CustomRegisteredDeviceMap::Instance() - .GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType())); - res->SetPlace( - paddleplace, custom_place.GetDeviceId(), place_.GetDeviceType()); + res->SetPlace(PaddlePlace::kCUSTOM, + custom_place.GetDeviceId(), + custom_place.GetDeviceType()); } else { auto gpu_place = place_; res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId()); @@ -2064,12 +2061,9 @@ std::unique_ptr AnalysisPredictor::GetOutputTensor( } } else if (platform::is_custom_place(place_)) { auto custom_place = place_; - auto paddleplace = static_cast( - static_cast(PaddlePlace::kCUSTOM) + - phi::CustomRegisteredDeviceMap::Instance() - .GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType())); - res->SetPlace( - paddleplace, custom_place.GetDeviceId(), place_.GetDeviceType()); + res->SetPlace(PaddlePlace::kCUSTOM, + custom_place.GetDeviceId(), + custom_place.GetDeviceType()); } else { auto gpu_place = place_; res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId()); @@ -2893,6 +2887,7 @@ USE_TRT_CONVERTER(sign); #endif USE_TRT_CONVERTER(rsqrt); USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm) +USE_TRT_CONVERTER(prompt_tuning_emb_eltwise_layernorm); USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm); USE_TRT_CONVERTER(preln_skip_layernorm) USE_TRT_CONVERTER(fused_bias_dropout_residual_layer_norm) diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 193e244f86e38..7a399bb55fe7b 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -244,16 +244,11 @@ void Tensor::CopyFromCpu(const T *data) { "Can not create tensor with XPU place because paddle is not compiled " "with XPU.")); #endif - } else { + } else if (place_ == PlaceType::kCUSTOM) { #ifdef PADDLE_WITH_CUSTOM_DEVICE - auto device_type_id = - static_cast(place_) - 
static_cast(PlaceType::kCUSTOM); paddle::platform::DeviceContextPool &pool = paddle::platform::DeviceContextPool::Instance(); - paddle::platform::CustomPlace custom_place( - phi::CustomRegisteredDeviceMap::Instance().GetGlobalDeviceType( - device_type_id), - device_); + paddle::platform::CustomPlace custom_place(device_type_, device_); auto *t_data = tensor->mutable_data(custom_place); auto *dev_ctx = static_cast( pool.Get(custom_place)); @@ -264,9 +259,15 @@ void Tensor::CopyFromCpu(const T *data) { ele_size, dev_ctx->stream()); #else - PADDLE_THROW(paddle::platform::errors::InvalidArgument( - "The analysis predictor supports CPU, GPU and XPU now.")); + PADDLE_THROW(paddle::platform::errors::Unavailable( + "Can not create tensor with Custom place because paddle is not " + "compiled with CustomDevice.")); #endif + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The analysis predictor supports CPU, GPU, XPU and CUSTOM_DEVICE " + "now.")); } } @@ -355,6 +356,14 @@ void Tensor::ShareExternalData(const T *data, const_cast(data), size, paddle::platform::XPUPlace(device_)), meta); *tensor = std::move(dtensor); + } else if (place == PlaceType::kCUSTOM) { + phi::DenseTensor dtensor( + std::make_shared( + const_cast(data), + size, + paddle::platform::CustomPlace(device_type_, device_)), + meta); + *tensor = std::move(dtensor); } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "PlaceType must be one of [PlaceType::kCPU, PlaceType::kGPU, " diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index ba71eff17387d..2058525946914 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -89,23 +89,24 @@ void PaddlePassBuilder::ClearPasses() { passes_.clear(); } const std::vector kTRTSubgraphPasses({ "trt_support_nhwc_pass", - "adaptive_pool2d_convert_global_pass", // - "trt_map_ops_to_matrix_multiply_pass", // - "shuffle_channel_detect_pass", // - "quant_conv2d_dequant_fuse_pass", // - "delete_quant_dequant_op_pass", // - "delete_quant_dequant_filter_op_pass", // - "trt_delete_weight_dequant_linear_op_pass", // - "delete_quant_dequant_linear_op_pass", // - "identity_op_clean_pass", // - "add_support_int8_pass", // - "simplify_with_basic_ops_pass", // - "trt_embedding_eltwise_layernorm_fuse_pass", // - "preln_embedding_eltwise_layernorm_fuse_pass", // - "trt_multihead_matmul_fuse_pass_v2", // - "trt_multihead_matmul_fuse_pass_v3", // - "multihead_matmul_roformer_fuse_pass", // - "constant_folding_pass", // + "adaptive_pool2d_convert_global_pass", // + "trt_map_ops_to_matrix_multiply_pass", // + "shuffle_channel_detect_pass", // + "quant_conv2d_dequant_fuse_pass", // + "delete_quant_dequant_op_pass", // + "delete_quant_dequant_filter_op_pass", // + "trt_delete_weight_dequant_linear_op_pass", // + "delete_quant_dequant_linear_op_pass", // + "identity_op_clean_pass", // + "add_support_int8_pass", // + "simplify_with_basic_ops_pass", // + "trt_prompt_tuning_embedding_eltwise_layernorm_fuse_pass", // + "trt_embedding_eltwise_layernorm_fuse_pass", // + "preln_embedding_eltwise_layernorm_fuse_pass", // + "trt_multihead_matmul_fuse_pass_v2", // + "trt_multihead_matmul_fuse_pass_v3", // + "multihead_matmul_roformer_fuse_pass", // + "constant_folding_pass", // #ifdef PADDLE_WITH_TENSORRT #if !IS_TRT_VERSION_GE(8610) "trt_flash_multihead_matmul_fuse_pass", // @@ -527,8 +528,9 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "one_beam_size_fuse_pass", 
"fold_interp_outsize_fuse_pass", "fold_two_squeeze2_fuse_pass", - "conv1d_xpu_fuse_pass", + // "conv1d_xpu_fuse_pass", "duplicated_transpose_fuse_pass", + "conv2d_bias_fuse_pass", "redundant_unsqueeze_squeeze_elimination_pass", "reduce_ops_fuse_pass", "delete_cast_op_pass", diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index dad1b073d51f2..2471c365e29ed 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -115,7 +115,7 @@ list( if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7) list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc - preln_emb_eltwise_layernorm.cc) + preln_emb_eltwise_layernorm.cc prompt_tuning_emb_eltwise_layernorm.cc) endif() if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) diff --git a/paddle/fluid/inference/tensorrt/convert/prompt_tuning_emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/prompt_tuning_emb_eltwise_layernorm.cc new file mode 100644 index 0000000000000..f6e99461695c5 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/prompt_tuning_emb_eltwise_layernorm.cc @@ -0,0 +1,177 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/utils.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/helper.h" +#include "paddle/phi/core/ddim.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class PromptTuningEmbEltwiseLayerNormOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(4) << "convert fused_prompt_tuning_embedding_eltwise_layernorm op to " + "tensorrt layer"; + // get the persistable var's data + auto GetWeight = [&](const std::string& var_name, + framework::DDim* dim) -> TensorRTEngine::Weight { + auto* temp_var = scope.FindVar(var_name); + auto* temp_tensor = temp_var->GetMutable(); + *dim = temp_tensor->dims(); + auto weight = engine_->GetTrtWeight(var_name, *temp_tensor); + return weight; + }; + + framework::OpDesc op_desc(op, nullptr); + auto* dense_vector = engine_->GetITensor(op_desc.Input("DenseVector")[0]); + + auto pos_id_name = engine_->tensorrt_transformer_posid(); + auto mask_id_name = engine_->tensorrt_transformer_maskid(); + + // bool with_fp16 = engine_->WithFp16() && + // !engine_->disable_trt_plugin_fp16(); + // int hidden = 0; + // Declare inputs + std::vector input_ids; + + // Declare inputs_weight + std::vector input_embs; + std::vector emb_sizes; + TensorRTEngine::Weight weight; + framework::DDim emb_dims; + framework::DDim bias_dims, scale_dims; + TensorRTEngine::Weight bias_weight, scale_weight; + + int64_t bias_size = phi::product(bias_dims); + int64_t scale_size = phi::product(scale_dims); + bool enable_int8 = op_desc.HasAttr("enable_int8"); + + std::vector id_names = op_desc.Input("Ids"); + std::vector emb_names = op_desc.Input("Embs"); + int input_num = id_names.size(); + + engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name)); + engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name)); + for (int i = 0; i < input_num; i++) { + auto input_tensor = engine_->GetITensor(id_names[i]); + weight = GetWeight(emb_names[i], &emb_dims); + if (id_names[i] == pos_id_name) { + input_ids.insert(input_ids.begin(), input_tensor); + input_embs.insert(input_embs.begin(), weight.get()); + emb_sizes.insert(emb_sizes.begin(), weight.get().count); + } else { + input_ids.push_back(input_tensor); + input_embs.push_back(weight.get()); + emb_sizes.push_back(weight.get().count); + } + } + bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims); + scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims); + bias_size = phi::product(bias_dims); + scale_size = phi::product(scale_dims); + // other_id(except pos_id) + engine_->SetITensor("word_id", input_ids[1]); + + int output_fp16 = static_cast((engine_->WithFp16() == 1) ? 1 : 0); + if (enable_int8) { + output_fp16 = 1; + } + PADDLE_ENFORCE_EQ( + output_fp16, + 1, + platform::errors::InvalidArgument( + "Only Precision::KHalf(fp16) is supported when inferring " + "ernie(bert) model with config.EnableVarseqlen(). 
" + "But Precision::KFloat32 is setted.")); + + std::vector fields; + std::vector temp_fields_keys; + fields.emplace_back("bert_embeddings_layernorm_beta", + bias_weight.get().values, + GetPluginFieldType(bias_weight.get().type), + static_cast(bias_size)); + fields.emplace_back("bert_embeddings_layernorm_gamma", + scale_weight.get().values, + GetPluginFieldType(scale_weight.get().type), + static_cast(scale_size)); + fields.emplace_back( + "output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1); + for (int i = 0; i < input_num; ++i) { + temp_fields_keys.push_back("bert_embeddings_word_embeddings_" + + std::to_string(i)); + fields.emplace_back(temp_fields_keys.rbegin()->c_str(), + input_embs[i].values, + GetPluginFieldType(input_embs[i].type), + static_cast(emb_sizes[i])); + } + + nvinfer1::PluginFieldCollection* plugin_ptr = + static_cast( + malloc(sizeof(*plugin_ptr) + + fields.size() * sizeof(nvinfer1::PluginField))); + plugin_ptr->nbFields = static_cast(fields.size()); + plugin_ptr->fields = fields.data(); + + std::vector plugin_inputs = input_ids; + plugin_inputs.emplace_back( + engine_->GetITensor("mask_id")); // input mask_id + + plugin_inputs.emplace_back(dense_vector); // prompt_tuning'dense_vector + + auto creator = GetPluginRegistry()->getPluginCreator( + "PromptTuningEmbLayerNormVarlenPluginDynamic", "1"); + auto plugin_obj = creator->createPlugin( + "PromptTuningEmbLayerNormVarlenPluginDynamic", plugin_ptr); + + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); + + plugin_layer->setName( + ("PromptTuningEmbLayerNormVarlenPluginDynamicV1(Output: " + + op_desc.Output("Out")[0] + ")") + .c_str()); + free(plugin_ptr); + if (enable_int8) { + float out_scale = + PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold")); + engine_->SetTensorDynamicRange(plugin_layer->getOutput(0), + out_scale); // output + engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), + out_scale); // mask + engine_->SetTensorDynamicRange(plugin_layer->getOutput(2), + out_scale); // max seqlen + } + + engine_->DeleteITensor("mask_id", engine_->GetITensor("mask_id")); + engine_->DeleteITensor("pos_id", engine_->GetITensor("pos_id")); + + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(plugin_layer, + "PromptTuningEmbLayerNormVarlenPluginDynamicV1", + {output_name, + std::string("qkv_plugin_mask"), + std::string("max_seqlen_tensor"), + std::string("mask_id"), + std::string("pos_id")}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(prompt_tuning_emb_eltwise_layernorm, + PromptTuningEmbEltwiseLayerNormOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 2378cf97be982..a00d97a21fb47 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -2870,6 +2870,7 @@ struct SimpleOpTypeSetTeller : public Teller { "relu6", "hard_sigmoid", "clip", + "prompt_tuning_emb_eltwise_layernorm", "fused_embedding_eltwise_layernorm", "multihead_matmul", "multihead_matmul_roformer", @@ -3036,6 +3037,7 @@ struct SimpleOpTypeSetTeller : public Teller { "relu6", "hard_sigmoid", "clip", + "prompt_tuning_emb_eltwise_layernorm", "fused_embedding_eltwise_layernorm", "multihead_matmul", "multihead_matmul_roformer", diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 
b1df5a733623e..bfc9e6b9072da 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -40,7 +40,9 @@ list( elementwiseadd_transpose_op_plugin.cu generic_plugin.cu many_emb_layernorm_plugin.cu - many_emb_layernorm_kernel.cu) + many_emb_layernorm_kernel.cu + prompt_tuning_emb_layernorm_varseqlen_kernel_hface.cu + prompt_tuning_emb_layernorm_varseqlen_plugin.cu) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7) list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu diff --git a/paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_kernel_hface.cu b/paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_kernel_hface.cu new file mode 100644 index 0000000000000..919ff565870a8 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_kernel_hface.cu @@ -0,0 +1,204 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & +// AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "NvInfer.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh" +#include "paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +template +__global__ void prompt_tuning_embKernel(int32_t B, + int32_t ld, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t const* inputIds2, + T const* dense_vector, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + T* output, + int32_t* new_pos_id) { + cub::Sum pairSum; + int32_t const s = blockIdx.x; + int32_t const b = blockIdx.y; + int32_t const sumS = inputIds0[b]; + int32_t const s_b = inputIds0[b + 1] - inputIds0[b]; + + int32_t const new_sumS = sumS + b; + + // new pos_id: Add an id to each sentence + new_pos_id[b] = new_sumS; + + // last id + if (b == B - 1) { + new_pos_id[B] = inputIds0[B] + B; + } + + T const rld = T(1.f) / T(ld); + int32_t const seqPos = sumS + s; + int32_t const out_seqPos = new_sumS + s + 1; + int32_t const new_out_seqPos = new_sumS + s; + + kvp threadData(0, 0); + + int32_t const new_outoffset = new_out_seqPos * ld; + int32_t const prompt_tuning_offset = new_sumS * ld; + int32_t const dense_vector_offset = b * ld; + + if (s < s_b) { + extern __shared__ int32_t word_id[]; + if (threadIdx.x == 0) { + if (static_cast(inputIds1)[seqPos] < 0 || + static_cast(inputIds1)[seqPos] >= IdsSize1) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[0] = static_cast(inputIds1)[seqPos]; + } + + if (static_cast(inputIds2)[seqPos] < 0 || + static_cast(inputIds2)[seqPos] >= 
IdsSize2) { + printf( + "Error!!!!!!(embLayerNormVarSeqlenPlugin): ID cannot be lookup " + "table: ID < 0 or ID > max "); + return; + } else { + word_id[1] = static_cast(inputIds2)[seqPos]; + } + } + __syncthreads(); + + // 2. load pos/tok/word embeddings and add them toghether + // offset into embeddings is given by wordId * hidden_size + int32_t const poffset = blockIdx.x * ld; + int32_t const outoffset = out_seqPos * ld; + + // the output offset is given by b * (S*hidden_size) + s * hidden_size + + for (int32_t it = threadIdx.x; it < ld; it += TPB) { + T p(mIdsEmbDev0[poffset + it]); // pos id + T val = p; + int32_t const offset0 = word_id[0] * ld; + val += mIdsEmbDev1[offset0 + it]; + int32_t const offset1 = word_id[1] * ld; + val += mIdsEmbDev2[offset1 + it]; + output[outoffset + it] = val; + T const rldval = rld * val; + threadData = pairSum(threadData, kvp(rldval, rldval * val)); + } + // 3. layer norm on the sum + layerNorm(threadData, ld, outoffset, beta, gamma, output); + } else if (s == s_b) { + for (int32_t it = threadIdx.x; it < ld; it += TPB) { + T val = dense_vector[dense_vector_offset + it]; + output[prompt_tuning_offset + it] = val; + T const rldval = rld * val; + threadData = pairSum(threadData, kvp(rldval, rldval * val)); + // 3. layer norm on the sum + } + layerNorm( + threadData, ld, prompt_tuning_offset, beta, gamma, output); + + } else { + return; // This CTA has nothing to do + } +} + +template +int32_t prompt_tuning_emb(cudaStream_t stream, + int32_t ld, + int32_t B, + int32_t S, + int const* inputIds0, + int const* inputIds1, + int const* inputIds2, + T const* dense_vector, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + T* output, + int32_t* new_pos_id) { + constexpr int32_t tpb = 256; + dim3 const grid(S, B, 1); + dim3 const block(tpb, 1, 1); + size_t cache_size = sizeof(int32_t) * (nbLookupTables - 1); + prompt_tuning_embKernel + <<>>(B, + ld, + inputIds0, + inputIds1, + inputIds2, + dense_vector, + beta, + gamma, + mIdsEmbDev0, + mIdsEmbDev1, + mIdsEmbDev2, + IdsSize0, + IdsSize1, + IdsSize2, + output, + new_pos_id); + return cudaPeekAtLastError(); +} + +template int32_t prompt_tuning_emb(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + half const*, + int32_t, + float const*, + float const*, + half const*, + half const*, + half const*, + int32_t, + int32_t, + int32_t, + half*, + int32_t*); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_plugin.cu new file mode 100644 index 0000000000000..64fde0785fdc7 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_plugin.cu @@ -0,0 +1,562 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & +// AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_plugin.h" +#include +#include +#include +#include "NvInfer.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +constexpr size_t threadsPerCta128 = 2 * 2 * 32; +constexpr size_t threadsPerCta256 = 1 * 4 * 32; +constexpr size_t threadsPerCta384 = 1 * 8 * 32; +// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M +// dimension: (s + 16*warps_m - 1) / (16*warps_m); +constexpr size_t xmmasM128 = 4; +constexpr size_t xmmasM256 = 16; +constexpr size_t xmmasM384 = 24; +// Packed mask size per batch. Layout is XMMAS_M * THREADS_PER_CTA. +constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128; +constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256; +constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384; +char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"1"}; +char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{ + "PromptTuningEmbLayerNormVarlenPluginDynamic"}; +// Static class fields initialization +nvinfer1::PluginFieldCollection + TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator::mFC{}; +std::vector + TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator::mPluginAttributes; + +TrtPromptTuningEmbLayerNormVarSeqlenPluginBase:: + TrtPromptTuningEmbLayerNormVarSeqlenPluginBase( + std::string const& name, + nvinfer1::DataType const type, + nvinfer1::Weights const& beta, + nvinfer1::Weights const& gamma, + const std::vector& IdsEmb) + : mLayerName(name), + mLd(beta.count), + mType(type), + mIdsEmb_(IdsEmb), + nbLookupTables_(static_cast(IdsEmb.size())) { + // Assuming Weights.count is the number of elements and not bytes + assert(beta.count == gamma.count); + mBeta.convertAndCopy(beta, nvinfer1::DataType::kFLOAT); + mGamma.convertAndCopy(gamma, nvinfer1::DataType::kFLOAT); + copyToDevice(&mGamma, sizeof(float) * mGamma.count, &mGammaDev); + copyToDevice(&mBeta, sizeof(float) * mBeta.count, &mBetaDev); + for (size_t i = 0; i < mIdsEmb_.size(); ++i) { + assert(mIdsEmb_[i].count % mLd == 0); + mIdsVocabSize.push_back(int32_t(mIdsEmb_[i].count / mLd)); + WeightsWithOwnership tem_weight; + tem_weight.convertAndCopy(mIdsEmb_[i], mType); + void* cudaMem{nullptr}; + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMalloc(&cudaMem, getWeightsSize(tem_weight, mType))); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(cudaMem, + tem_weight.values, + getWeightsSize(tem_weight, mType), + cudaMemcpyHostToDevice)); + mIdsEmbPtrs.push_back(cudaMem); + } +} + +TrtPromptTuningEmbLayerNormVarSeqlenPluginBase:: + TrtPromptTuningEmbLayerNormVarSeqlenPluginBase(std::string const& name, + void const* data, + size_t length) + : mLayerName(name), + mGammaDev(nullptr), + mBetaDev(nullptr), + mIdsEmbPtrs{}, + mIdsEmb_{} { + // Deserialize in the same order as serialization + deserialize_value(&data, &length, &mType); + deserialize_value(&data, &length, &mLd); + deserialize_value(&data, &length, &nbLookupTables_); + for (int32_t i = 0; i < nbLookupTables_; ++i) { + int32_t tem; + deserialize_value(&data, &length, &tem); + 
mIdsVocabSize.push_back(tem); + } + char const* d = static_cast(data); + mBeta.convertAndCopy(&d, mLd, nvinfer1::DataType::kFLOAT); + mGamma.convertAndCopy(&d, mLd, nvinfer1::DataType::kFLOAT); + for (int32_t i = 0; i < nbLookupTables_; ++i) { + nvinfer1::Weights pre_tem_weight; + pre_tem_weight.type = mType; + pre_tem_weight.count = mLd * size_t(mIdsVocabSize[i]); + const auto nbBytes = mLd * size_t(mIdsVocabSize[i]) * getElementSize(mType); + auto destBuf = new char[nbBytes]; + pre_tem_weight.values = destBuf; + std::copy_n(d, nbBytes, destBuf); + d += nbBytes; + mIdsEmb_.push_back(pre_tem_weight); + } +} + +TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace:: + TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace( + std::string const& name, + nvinfer1::DataType const type, + nvinfer1::Weights const& beta, + nvinfer1::Weights const& gamma, + const std::vector& IdsEmb) + : TrtPromptTuningEmbLayerNormVarSeqlenPluginBase( + name, type, beta, gamma, IdsEmb) {} + +TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace:: + TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace(std::string const& name, + void const* data, + size_t length) + : TrtPromptTuningEmbLayerNormVarSeqlenPluginBase(name, data, length) { + TRANSFORMER_DEBUG_MSG( + "TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace deserialize"); +} + +// IPluginV2DynamicExt Methods +nvinfer1::IPluginV2DynamicExt* +TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace::clone() const noexcept { + TRANSFORMER_DEBUG_MSG( + "TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace clone"); + auto p = new TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace( + mLayerName, mType, mBeta, mGamma, mIdsEmb_); + p->setPluginNamespace(mNamespace.c_str()); + return p; +} + +nvinfer1::DimsExprs +TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace::getOutputDimensions( + int32_t outputIndex, + nvinfer1::DimsExprs const* inputs, + int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept { + for (int i = 1; i < nbInputs - 2; ++i) { + assert(inputs[i].nbDims == 1); // seq length + assert(inputs[i].nbDims == inputs[1].nbDims); // same shape + } + assert(inputs[0].nbDims == 1); // pos_id: B+1 + auto one = exprBuilder.constant(1); + auto Bplus1 = inputs[0].d[0]; // pos_id + auto B = + exprBuilder.operation(nvinfer1::DimensionOperation::kSUB, *Bplus1, *one); + if (outputIndex == 0) { + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, + *inputs[1].d[0], + *B); // sum of seq length + ret.d[1] = exprBuilder.constant(mLd); + ret.d[2] = exprBuilder.constant(1); + ret.d[3] = exprBuilder.constant(1); + return ret; + } else if (outputIndex == 1) { + // This is a hack: we just report some mask size and rely the plugins to + // play nicely together. + // At runtime, depending on the actual maxSeqlen, the size might be + // different. 
+ int32_t maskSize_ = packedMaskSize384; + auto maskSize = exprBuilder.constant(maskSize_); + auto fp16maskSize = + exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, + *maskSize, + *exprBuilder.constant(2)); + nvinfer1::DimsExprs ret; + ret.nbDims = 2; + ret.d[0] = B; + ret.d[1] = fp16maskSize; + return ret; + } else if (outputIndex == 2) { + nvinfer1::DimsExprs ret; + ret.nbDims = 1; + ret.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, + *inputs[nbInputs - 2].d[1], + *one); // max seqlen + return ret; + } else if (outputIndex == 3) { + nvinfer1::DimsExprs ret = inputs[nbInputs - 2]; // new mask_id + ret.d[1] = exprBuilder.operation( + nvinfer1::DimensionOperation::kSUM, *inputs[nbInputs - 2].d[1], *one); + return ret; + } else if (outputIndex == 4) { + nvinfer1::DimsExprs ret = inputs[0]; // new pos_id + return ret; + } +} + +bool TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::supportsFormatCombination( + int32_t pos, + nvinfer1::PluginTensorDesc const* inOut, + int32_t nbInputs, + int32_t nbOutputs) noexcept { + assert(nbOutputs == 5); + nvinfer1::PluginTensorDesc const& desc = inOut[pos]; + if (desc.format != nvinfer1::TensorFormat::kLINEAR) { + return false; + } + if (pos == 0) { // pos_id + return desc.dims.nbDims == 1 && desc.type == nvinfer1::DataType::kINT32; + } + if (pos == 1) { // input_id + return desc.dims.nbDims == 1 && desc.type == nvinfer1::DataType::kINT32; + } + nvinfer1::PluginTensorDesc const& prev = inOut[1]; // input_ids + if (1 < pos && + pos < (nbInputs - 2)) { // other ids: check it's the same as input_ids + return desc.type == prev.type && desc.dims.nbDims == 1 && + desc.dims.d[0] == prev.dims.d[0]; + } + if (pos == nbInputs - 2) { // mask id + return desc.type == mType; + } + if (pos == nbInputs - 1) { // dense vector + return desc.type == mType; + } + // embedded sequence + if (pos == nbInputs) { + return desc.type == mType && desc.dims.nbDims == 4 && desc.dims.d[2] == 1 && + desc.dims.d[3] == 1; + } + // mask(HFace) + if (pos == nbInputs + 1) { + return desc.type == mType; + } + // max seqlen + if (pos == nbInputs + 2) { + return desc.type == mType; + } + // new mask_id + if (pos == nbInputs + 3) { + return desc.type == mType; + } + // new pos_id + if (pos == nbInputs + 4) { + return desc.dims.nbDims == 1 && desc.type == nvinfer1::DataType::kINT32; + } +} + +void checkConfigurationInputs(nvinfer1::DynamicPluginTensorDesc const* inputs, + int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* outputs, + int32_t nbOutputs) noexcept { + // Validate input arguments + assert(nbOutputs == 5); + assert(inputs[0].desc.dims.nbDims == 1); + assert(inputs[0].desc.type == nvinfer1::DataType::kINT32); + for (int i = 1; i < nbInputs - 2; ++i) { + assert(inputs[i].desc.dims.nbDims == 1); + assert(inputs[i].desc.dims.d[0] == inputs[1].desc.dims.d[0]); + assert(inputs[i].desc.type == nvinfer1::DataType::kINT32); + } + assert(outputs[0].desc.dims.nbDims == 4); + assert(outputs[0].desc.dims.d[2] == 1); + assert(outputs[0].desc.dims.d[3] == 1); +} + +void TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace::configurePlugin( + nvinfer1::DynamicPluginTensorDesc const* inputs, + int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* outputs, + int32_t nbOutputs) noexcept { + TRANSFORMER_DEBUG_MSG( + "TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace configurePlugin"); + checkConfigurationInputs(inputs, nbInputs, outputs, nbOutputs); + assert(static_cast(outputs[0].desc.dims.d[1]) == + static_cast(mLd)); + int32_t const B = inputs[0].desc.dims.d[0] - 
1; + // check mask + assert(outputs[1].desc.dims.nbDims == 2); + if (B > 0) { + assert(outputs[1].desc.dims.d[0] == B); + } + assert((outputs[1].desc.dims.d[1] == 2 * packedMaskSize384) || + (outputs[1].desc.dims.d[1] == 2 * packedMaskSize128) || + (outputs[1].desc.dims.d[1] == 2 * packedMaskSize256)); + assert(outputs[0].desc.type == mType); + assert(outputs[1].desc.type == nvinfer1::DataType::kHALF); +} + +size_t TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::getWorkspaceSize( + nvinfer1::PluginTensorDesc const* inputs, + int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept { + return 0; +} + +int32_t TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace::enqueue( + nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept { + int32_t batchSize = inputDesc[0].dims.d[0] - 1; + // read out the maximum sequence length from the dummy input + int32_t const maxSeqlen = inputDesc[nbLookupTables_].dims.d[1] + 1; + int32_t S; + if (maxSeqlen <= 128) { + S = 128; + } else if (maxSeqlen <= 192) { + S = 192; + } else if (maxSeqlen <= 256) { + S = 256; + } else if (maxSeqlen <= 384) { + S = 384; + } else if (maxSeqlen <= 512) { + S = 512; + } else { + PADDLE_THROW(platform::errors::Fatal("The max seqlenth is 512.")); + } + const float* beta = mBetaDev.get(); + const float* gamma = mGammaDev.get(); + + auto output = static_cast(outputs[0]); + auto new_pos_id = static_cast(outputs[4]); + return prompt_tuning_emb(stream, + static_cast(mLd), + batchSize, + S, + static_cast(inputs[0]), + static_cast(inputs[1]), + static_cast(inputs[2]), + static_cast(inputs[4]), + nbLookupTables_, + beta, + gamma, + static_cast(mIdsEmbPtrs[0]), + static_cast(mIdsEmbPtrs[1]), + static_cast(mIdsEmbPtrs[2]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + output, + new_pos_id); +} + +// IPluginV2Ext Methods +nvinfer1::DataType +TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::getOutputDataType( + int32_t index, + nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept { + assert(mType == nvinfer1::DataType::kHALF); + if (index == 0) { + return mType; + } else if (index == 1) { + return mType; + } else if (index == 2) { + return mType; + } else if (index == 3) { + return mType; + } else if (index == 4) { + return nvinfer1::DataType::kINT32; + } +} + +// IPluginV2 Methods +char const* TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::getPluginType() + const noexcept { + return EMB_LAYER_NORM_VAR_SEQLEN_NAME; +} + +char const* TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace::getPluginVersion() + const noexcept { + return EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE; +} + +int32_t TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::getNbOutputs() + const noexcept { + return 5; +} + +int32_t TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace::initialize() noexcept { + TRANSFORMER_DEBUG_MSG( + "TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace initialize"); + return 0; +} + +void TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace::terminate() noexcept { + TRANSFORMER_DEBUG_MSG( + "TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace terminate"); +} + +size_t TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::getSerializationSize() + const noexcept { + size_t const wordSize = getElementSize(mType); + return 2 * sizeof(float) * mLd // beta + gamma + + sizeof(mType) // + + sizeof(mLd) // + + mIdsVocabSize.size() * sizeof(mIdsVocabSize[0]) // + + wordSize 
* mLd * + accumulate( + mIdsVocabSize.begin(), mIdsVocabSize.end(), 0) // ids emb + + sizeof(nbLookupTables_); // numbers of lookup_table +} + +void TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::serialize( + void* buffer) const noexcept { + serialize_value(&buffer, mType); + serialize_value(&buffer, mLd); + serialize_value(&buffer, nbLookupTables_); + for (size_t i = 0; i < mIdsVocabSize.size(); ++i) { + serialize_value(&buffer, mIdsVocabSize[i]); + } + char* d = static_cast(buffer); + size_t const wordSize = getElementSize(mType); + serFromDev(&d, mBetaDev.get(), mLd); + serFromDev(&d, mGammaDev.get(), mLd); + for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) { + serFromDev(&d, + static_cast(mIdsEmbPtrs[i]), + mLd * mIdsVocabSize[i] * wordSize); + } +} + +void TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::destroy() noexcept { + // This gets called when the network containing plugin is destroyed + mBetaDev.reset(nullptr); + mGammaDev.reset(nullptr); + for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) { + cudaFree(mIdsEmbPtrs[i]); + } + delete this; +} + +void TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace::destroy() noexcept { + TRANSFORMER_DEBUG_MSG( + "TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace destroy"); + TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::destroy(); +} + +void TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::setPluginNamespace( + char const* libNamespace) noexcept { + mNamespace = libNamespace; +} + +char const* TrtPromptTuningEmbLayerNormVarSeqlenPluginBase::getPluginNamespace() + const noexcept { + return mNamespace.c_str(); +} + +TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator:: + TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator() = default; + +char const* +TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator::getPluginName() + const noexcept { + return EMB_LAYER_NORM_VAR_SEQLEN_NAME; +} + +char const* +TrtPromptTuningEmbLayerNormVarSeqlenPluginHFaceCreator::getPluginVersion() + const noexcept { + return EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE; +} + +nvinfer1::PluginFieldCollection const* +TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator:: + getFieldNames() noexcept { + return &mFC; +} + +bool initializeFields(nvinfer1::PluginFieldCollection const* fc, + nvinfer1::Weights* beta, + nvinfer1::Weights* gamma, + std::vector* IdsEmb) { + bool output_fp16 = false; + for (int32_t i = 0; i < fc->nbFields; i++) { + std::string field_name(fc->fields[i].name); + if (field_name.compare("bert_embeddings_layernorm_beta") == 0) { + TRANSFORMER_DEBUG_MSG("Building bert_embeddings_layernorm_beta..."); + beta->values = fc->fields[i].data; + beta->count = fc->fields[i].length; + beta->type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("bert_embeddings_layernorm_gamma") == 0) { + TRANSFORMER_DEBUG_MSG("Building bert_embeddings_layernorm_gamma..."); + gamma->values = fc->fields[i].data; + gamma->count = fc->fields[i].length; + gamma->type = fieldTypeToDataType(fc->fields[i].type); + } + + if (field_name.compare("output_fp16") == 0) { + TRANSFORMER_DEBUG_MSG("Building output_fp16..."); + assert(fc->fields[i].type == nvinfer1::PluginFieldType::kINT32); + output_fp16 = static_cast(fc->fields[i].data)[0] != 0; + } + if (field_name.compare("bert_embeddings_word_embeddings_" + + std::to_string(i - 3)) == 0) { + TRANSFORMER_DEBUG_MSG( + ("bert_embeddings_word_embeddings_" + std::to_string(i - 3)).c_str()); + nvinfer1::Weights tem; + tem.values = fc->fields[i].data; + tem.count = fc->fields[i].length; + tem.type = 
fieldTypeToDataType(fc->fields[i].type); + IdsEmb->push_back(tem); + } + } + return output_fp16; +} + +nvinfer1::IPluginV2* +TrtPromptTuningEmbLayerNormVarSeqlenPluginHFaceCreator::createPlugin( + char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept { + TRANSFORMER_DEBUG_MSG("EmbLayerNormVarSeqlenHFace createPlugin"); + nvinfer1::Weights beta; + nvinfer1::Weights gamma; + std::vector IdsEmb; + bool output_fp16 = initializeFields(fc, &beta, &gamma, &IdsEmb); + TRANSFORMER_DEBUG_MSG("Building the Plugin..."); + TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace* p = + new TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace( + name, + output_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + beta, + gamma, + IdsEmb); + return p; +} + +nvinfer1::IPluginV2* +TrtPromptTuningEmbLayerNormVarSeqlenPluginHFaceCreator::deserializePlugin( + char const* name, void const* serialData, size_t serialLength) noexcept { + return new TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace( + name, serialData, serialLength); +} + +void TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator::setPluginNamespace( + char const* libNamespace) noexcept { + mNamespace = libNamespace; +} + +char const* +TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator::getPluginNamespace() + const noexcept { + return mNamespace.c_str(); +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_plugin.h b/paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_plugin.h new file mode 100644 index 0000000000000..b479a992bbb5d --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/prompt_tuning_emb_layernorm_varseqlen_plugin.h @@ -0,0 +1,189 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & +// AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
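Since deserializePlugin above reconstructs the plugin straight from the serialized blob, the byte layout that serialize() and getSerializationSize() agree on is worth spelling out: scalars and per-table vocab sizes first, then the fp32 beta/gamma vectors, then the raw embedding tables. A standalone sketch of the size computation (the enum and size_t field widths are platform assumptions here, not TensorRT guarantees):

```cpp
// Standalone sketch of the serialized-blob size that serialize() and
// getSerializationSize() above must agree on. The sizeof assumptions for
// the enum (mType) and size_t (mLd) members are platform-specific.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

size_t PluginBlobSize(size_t ld,                         // hidden size
                      const std::vector<int32_t>& vocab, // rows per table
                      size_t word_size) {                // 2 for fp16
  const size_t rows = std::accumulate(vocab.begin(), vocab.end(), size_t{0});
  return sizeof(int32_t)                   // mType (assumed 4-byte enum)
         + sizeof(size_t)                  // mLd
         + sizeof(int32_t)                 // nbLookupTables_
         + vocab.size() * sizeof(int32_t)  // mIdsVocabSize entries
         + 2 * sizeof(float) * ld          // beta + gamma, always fp32
         + word_size * ld * rows;          // the embedding tables themselves
}

int main() {
  // e.g. hidden size 768, pos/word/sent tables of 512/30522/2 rows, fp16:
  std::cout << PluginBlobSize(768, {512, 30522, 2}, 2) << " bytes\n";
}
```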
+ +#pragma once +#include +#include "NvInferPlugin.h" +#include "NvInferRuntime.h" + +#include "paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/serialize.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +template +int32_t prompt_tuning_emb(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + T const*, + int32_t, + float const*, + float const*, + T const*, + T const*, + T const*, + int32_t, + int32_t, + int32_t, + T*, + int32_t*); +class TrtPromptTuningEmbLayerNormVarSeqlenPluginBase + : public nvinfer1::IPluginV2DynamicExt { + public: + TrtPromptTuningEmbLayerNormVarSeqlenPluginBase( + std::string const& name, + nvinfer1::DataType const type, + nvinfer1::Weights const& beta, + nvinfer1::Weights const& gamma, + const std::vector& ids_emb); + + TrtPromptTuningEmbLayerNormVarSeqlenPluginBase(std::string const& name, + void const* data, + size_t length); + + // It doesn't make sense to make TrtPromptTuningEmbLayerNormVarSeqlenPlugin + // without arguments, so we delete default constructor. + TrtPromptTuningEmbLayerNormVarSeqlenPluginBase() = delete; + + // IPluginV2DynamicExt Methods + bool supportsFormatCombination(int32_t pos, + nvinfer1::PluginTensorDesc const* inOut, + int32_t nbInputs, + int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, + int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType( + int32_t index, + nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override; + + // IPluginV2 Methods + char const* getPluginType() const noexcept override; + int32_t getNbOutputs() const noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + char const* getPluginNamespace() const noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + protected: + std::string const mLayerName; + std::string mNamespace; + cuda_unique_ptr mGammaDev; + cuda_unique_ptr mBetaDev; + std::vector mIdsEmbPtrs; + size_t mLd; // leading dim = hidden size + std::vector mIdsVocabSize; + WeightsWithOwnership mBeta; + WeightsWithOwnership mGamma; + nvinfer1::DataType mType; + std::vector mIdsEmb_; + int32_t nbLookupTables_ = 0; +}; + +class TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace + : public TrtPromptTuningEmbLayerNormVarSeqlenPluginBase { + public: + TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace( + std::string const& name, + nvinfer1::DataType const type, + nvinfer1::Weights const& beta, + nvinfer1::Weights const& gamma, + const std::vector& ids_emb); + + TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace(std::string const& name, + void const* data, + size_t length); + + // It doesn't make sense to make TrtPromptTuningEmbLayerNormVarSeqlenPlugin + // without arguments, so we delete default constructor. 
+  TrtPromptTuningEmbLayerNormVarSeqlenPluginHFace() = delete;
+
+  // IPluginV2DynamicExt Methods
+  nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+  nvinfer1::DimsExprs getOutputDimensions(
+      int32_t outputIndex,
+      const nvinfer1::DimsExprs* inputs,
+      int32_t nbInputs,
+      nvinfer1::IExprBuilder& exprBuilder) noexcept override;
+  void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in,
+                       int32_t nbInputs,
+                       nvinfer1::DynamicPluginTensorDesc const* out,
+                       int32_t nbOutputs) noexcept override;
+  int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
+                  nvinfer1::PluginTensorDesc const* outputDesc,
+                  void const* const* inputs,
+                  void* const* outputs,
+                  void* workspace,
+                  cudaStream_t stream) noexcept override;
+  // IPluginV2 Methods
+  int32_t initialize() noexcept override;
+  void terminate() noexcept override;
+  void destroy() noexcept override;
+  char const* getPluginVersion() const noexcept override;
+};
+
+class TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator
+    : public nvinfer1::IPluginCreator {
+ public:
+  TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator();
+
+  char const* getPluginName() const noexcept override;
+
+  const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
+
+  void setPluginNamespace(char const* pluginNamespace) noexcept override;
+
+  char const* getPluginNamespace() const noexcept override;
+
+ protected:
+  static nvinfer1::PluginFieldCollection mFC;
+  static std::vector<nvinfer1::PluginField> mPluginAttributes;
+  std::string mNamespace;
+};
+
+class TrtPromptTuningEmbLayerNormVarSeqlenPluginHFaceCreator
+    : public TrtPromptTuningEmbLayerNormVarSeqlenPluginBaseCreator {
+ public:
+  nvinfer1::IPluginV2* createPlugin(
+      char const* name,
+      const nvinfer1::PluginFieldCollection* fc) noexcept override;
+  char const* getPluginVersion() const noexcept override;
+  nvinfer1::IPluginV2* deserializePlugin(char const* name,
+                                         void const* serialData,
+                                         size_t serialLength) noexcept override;
+};
+
+REGISTER_TRT_PLUGIN_V2(TrtPromptTuningEmbLayerNormVarSeqlenPluginHFaceCreator);
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/ir/dialect/CMakeLists.txt b/paddle/fluid/ir/dialect/CMakeLists.txt
deleted file mode 100644
index 7500642867f30..0000000000000
--- a/paddle/fluid/ir/dialect/CMakeLists.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-add_subdirectory(paddle_dialect)
-add_subdirectory(paddle_kernel_dialect)
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/interface/CMakeLists.txt b/paddle/fluid/ir/dialect/paddle_dialect/interface/CMakeLists.txt
deleted file mode 100644
index 5ee2f3510ca93..0000000000000
--- a/paddle/fluid/ir/dialect/paddle_dialect/interface/CMakeLists.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-# All source files of pd_dialect, except for the source file of op, which is generated in the compilation directory.
-file(GLOB PD_INTERFACE_SRCS "*.cc")
-
-cc_library(
-  pd_interface
-  SRCS ${PD_INTERFACE_SRCS}
-  DEPS ir_core phi_utils)
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc
deleted file mode 100644
index b95d78a74f470..0000000000000
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
-#include "paddle/ir/core/builtin_op.h"
-
-namespace paddle {
-namespace dialect {
-ir::OpResult split_grad(std::vector<ir::OpResult> out_grads,
-                        ir::OpResult axis) {
-  auto combine_op =
-      APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(out_grads);
-  paddle::dialect::SplitGradOp split_grad_op =
-      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
-          combine_op.out(), axis);
-
-  return split_grad_op.x_grad();
-}
-
-ir::OpResult split_grad(std::vector<ir::OpResult> out_grads, int axis) {
-  auto combine_op =
-      APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(out_grads);
-  paddle::dialect::SplitGradOp split_grad_op =
-      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
-          combine_op.out(), axis);
-
-  return split_grad_op.x_grad();
-}
-}  // namespace dialect
-}  // namespace paddle
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h
deleted file mode 100644
index c8a5e1658ec4d..0000000000000
--- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h
+++ /dev/null
@@ -1,204 +0,0 @@
-// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
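// ---------------------------------------------------------------------------
// Side sketch (toy types, not Paddle's API): the deleted split_grad wrapper
// above shows the recurring recipe for ops whose input is a Tensor[]: first
// Build<ir::CombineOp>(...) packs the n OpResults into a single vector-typed
// value, then the consuming op is built on that one value. A self-contained
// analogue of that call-site shape:
// ---------------------------------------------------------------------------
#include <iostream>
#include <utility>
#include <vector>

struct Value { int id; };  // stand-in for ir::OpResult

struct CombineOp {  // stand-in for builtin.combine
  Value out_;
  explicit CombineOp(const std::vector<Value>& ins)
      : out_{static_cast<int>(ins.size())} {}
  Value out() const { return out_; }
};

struct SplitGradOp {  // stand-in for pd.split_grad
  Value x_grad_;
  SplitGradOp(Value packed, Value /*axis*/) : x_grad_{packed} {}
  Value x_grad() const { return x_grad_; }
};

struct Builder {  // mirrors the Build<OpT>(args...) builder idiom
  template <typename OpT, typename... Args>
  OpT Build(Args&&... args) { return OpT(std::forward<Args>(args)...); }
};

int main() {
  Builder b;
  auto combine = b.Build<CombineOp>(std::vector<Value>{{1}, {2}, {3}});
  auto grad = b.Build<SplitGradOp>(combine.out(), Value{0});
  std::cout << grad.x_grad().id << "\n";  // 3 inputs flowed through one value
}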
- -#ifdef GET_MANUAL_OP_LIST -#undef GET_MANUAL_OP_LIST -paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp - -#else - -#pragma once -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/trait/inplace.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" -#include "paddle/ir/core/builder.h" -#include "paddle/ir/core/ir_printer.h" -#include "paddle/ir/core/op_base.h" -#include "paddle/ir/core/operation_utils.h" -#include "paddle/phi/core/infermeta_utils.h" - -namespace paddle { -namespace dialect { - -class AddNOp : public ir::Op { - public: - using Op::Op; - static const char *name() { return "pd.add_n"; } - static constexpr const char **attributes_name = nullptr; - static constexpr uint32_t attributes_num = 0; - static OpInfoTuple GetOpInfo(); - static void Build(ir::Builder &builder, // NOLINT - ir::OperationArgument &argument, // NOLINT - ir::OpResult inputs); - - void Verify(); - ir::Value inputs() { return operand_source(0); } - ir::OpResult out() { return result(0); } - static void InferMeta(phi::InferMetaContext *infer_meta); -}; - -class AddN_Op : public ir::Op { - public: - using Op::Op; - static const char *name() { return "pd.add_n_"; } - static constexpr const char **attributes_name = nullptr; - static constexpr uint32_t attributes_num = 0; - static OpInfoTuple GetOpInfo(); - static void Build(ir::Builder &builder, // NOLINT - ir::OperationArgument &argument, // NOLINT - ir::OpResult inputs_); - - void Verify(); - ir::Value inputs() { return operand_source(0); } - ir::OpResult out() { return result(0); } - - static void InferMeta(phi::InferMetaContext *infer_meta); -}; - -class AddNWithKernelOp : public ir::Op { - public: - using Op::Op; - static const char *name() { return "pd.add_n_with_kernel"; } - static constexpr const char **attributes_name = nullptr; - static constexpr uint32_t attributes_num = 0; - static OpInfoTuple GetOpInfo(); - static void Build(ir::Builder &builder, // NOLINT - ir::OperationArgument &argument, // NOLINT - ir::OpResult inputs_); - - void Verify(); - ir::Value inputs() { return operand_source(0); } - ir::OpResult out() { return result(0); } - - static void InferMeta(phi::InferMetaContext *infer_meta); -}; - -class FusedGemmEpilogueOp : public ir::Op { - public: - using Op::Op; - static const char *name() { return "pd.fused_gemm_epilogue"; } - static const char *attributes_name[3]; - static constexpr uint32_t attributes_num = 3; - static OpInfoTuple GetOpInfo(); - - static void Build(ir::Builder &builder, // NOLINT - ir::OperationArgument &argument, // NOLINT - ir::OpResult x_, - ir::OpResult y_, - ir::OpResult bias_, - ir::AttributeMap attributes); - void Verify(); - ir::Value x() { return operand_source(0); } - ir::Value y() { return operand_source(1); } - ir::Value bias() { return operand_source(2); } - ir::OpResult out() { return result(0); } - ir::OpResult reserve_space() { return result(1); } - - static void InferMeta(phi::InferMetaContext *infer_meta); -}; - -class FusedGemmEpilogueGradOp - : public ir::Op { - public: - using Op::Op; - static const char *name() { return "pd.fused_gemm_epilogue_grad"; } - static const char *attributes_name[3]; - static constexpr uint32_t attributes_num = 3; - static OpInfoTuple GetOpInfo(); 
- - static void Build(ir::Builder &builder, // NOLINT - ir::OperationArgument &argument, // NOLINT - ir::OpResult x_, - ir::OpResult y_, - ir::OpResult reserve_space_, - ir::OpResult out_grad_, - ir::AttributeMap attributes); - void Verify(); - ir::Value x() { return operand_source(0); } - ir::Value y() { return operand_source(1); } - ir::Value reserve_space() { return operand_source(2); } - ir::Value out_grad() { return operand_source(3); } - ir::OpResult x_grad() { return result(0); } - ir::OpResult y_grad() { return result(1); } - ir::OpResult bias_grad() { return result(2); } - - static void InferMeta(phi::InferMetaContext *infer_meta); -}; - -class SplitGradOp : public ir::Op { - public: - using Op::Op; - static const char *name() { return "pd.split_grad"; } - static const char *attributes_name[1]; - static constexpr uint32_t attributes_num = 1; - static OpInfoTuple GetOpInfo(); - static void Build(ir::Builder &builder, // NOLINT - ir::OperationArgument &argument, // NOLINT - ir::OpResult x_, - float axis = 0); - static void Build(ir::Builder &builder, // NOLINT - ir::OperationArgument &argument, // NOLINT - ir::OpResult out_grad_, - ir::OpResult axis_); - - void Verify(); - ir::Value out_grad() { return operand_source(0); } - ir::Value axis() { return operand_source(1); } - ir::OpResult x_grad() { return result(0); } - static void InferMeta(phi::InferMetaContext *infer_meta); -}; - -class IfOp : public ir::Op { - public: - using Op::Op; - static const char *name() { return "pd.if"; } - static constexpr const char **attributes_name = nullptr; - static constexpr uint32_t attributes_num = 0; - static void Build(ir::Builder &builder, // NOLINT - ir::OperationArgument &argument, // NOLINT - ir::OpResult cond, - std::vector &&output_types); - ir::Value cond() { return operand_source(0); } - ir::Block *true_block(); - ir::Block *false_block(); - void Print(ir::IrPrinter &printer); // NOLINT - void Verify(); -}; - -} // namespace dialect -} // namespace paddle - -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp) - -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) -#endif diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op_vjp.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op_vjp.cc deleted file mode 100644 index c7bac02e3347e..0000000000000 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op_vjp.cc +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
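// ---------------------------------------------------------------------------
// Orientation sketch for the deleted SumOp::Vjp below (toy types, not
// Paddle's real API): a hand-written vjp is thin glue that (1) wraps IR
// values into tracer tensors, (2) delegates to the shared primitive::*_vjp
// rule, and (3) unwraps the produced tensors back into IR values.
// ---------------------------------------------------------------------------
#include <iostream>

struct IrValue { int id = -1; };   // stand-in for ir::OpResult
struct LazyTensor { IrValue v; };  // stand-in for primitive::LazyTensor

// stand-in for primitive::sum_vjp: a real rule would build gradient ops
LazyTensor sum_vjp_rule(const LazyTensor& out_grad) { return out_grad; }

IrValue SumVjp(IrValue out_grad_value) {
  LazyTensor out_grad{out_grad_value};       // (1) wrap
  LazyTensor grad = sum_vjp_rule(out_grad);  // (2) delegate
  return grad.v;                             // (3) unwrap
}

int main() { std::cout << SumVjp(IrValue{7}).id << "\n"; }  // prints 7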
- -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" -#include "paddle/fluid/primitive/rule/vjp/vjp.h" -#include "paddle/fluid/primitive/type/lazy_tensor.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/op_base.h" -#include "paddle/phi/common/int_array.h" - -// TODO(wanghao107) -// this file will be generated in pd_op.cc - -namespace paddle { -namespace dialect { -using IntArray = paddle::experimental::IntArray; - -std::vector> SumOp::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - SumOp op_obj = op->dyn_cast(); - Tensor x(std::make_shared(op_obj.x())); - Tensor out_grad(std::make_shared(out_grads[0][0])); - - Tensor axis(std::make_shared(op_obj.axis())); - - bool keepdim = op->attribute("keepdim").dyn_cast().data(); - bool reduce_all = false; - std::vector> tensor_res = primitive::sum_vjp( - x, out_grad, axis, keepdim, reduce_all, stop_gradients); - std::vector> res(2, std::vector(1)); - if (tensor_res[0][0].defined()) { - res[0][0] = - std::static_pointer_cast(tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); - } - return res; -} - -} // namespace dialect -} // namespace paddle diff --git a/paddle/fluid/ir_adaptor/translator/CMakeLists.txt b/paddle/fluid/ir_adaptor/translator/CMakeLists.txt index 632411383db56..4ac1dc065143f 100644 --- a/paddle/fluid/ir_adaptor/translator/CMakeLists.txt +++ b/paddle/fluid/ir_adaptor/translator/CMakeLists.txt @@ -20,4 +20,4 @@ file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc") cc_library( program_translator SRCS ${PD_PROGRAM_TRANSLATOR_SRCS} ${op_compat_source_file} - DEPS proto_desc pd_dialect ir framework_proto) + DEPS proto_desc pd_op_dialect pir framework_proto) diff --git a/paddle/fluid/ir_adaptor/translator/attribute_translator.cc b/paddle/fluid/ir_adaptor/translator/attribute_translator.cc index f6a4b94f2bfdf..ebb58cc0ebf61 100644 --- a/paddle/fluid/ir_adaptor/translator/attribute_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/attribute_translator.cc @@ -17,14 +17,14 @@ #include #include -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include "paddle/ir/core/enforce.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/common/place.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/utils/data_type.h" +#include "paddle/pir/core/enforce.h" #include "paddle/utils/variant.h" namespace paddle { @@ -32,127 +32,128 @@ namespace translator { class AttributeVisitor { public: - ir::IrContext* ctx; - AttributeVisitor() { ctx = ir::IrContext::Instance(); } + pir::IrContext* ctx; + AttributeVisitor() { ctx = pir::IrContext::Instance(); } ~AttributeVisitor() = default; public: - virtual ir::Attribute operator()(int i) { + virtual pir::Attribute operator()(int i) { VLOG(10) << "translating int"; - return ir::Int32Attribute::get(ctx, i); + return pir::Int32Attribute::get(ctx, i); } - virtual ir::Attribute operator()(int64_t i) { + virtual pir::Attribute operator()(int64_t i) { VLOG(10) << "translating int"; - return ir::Int64Attribute::get(ctx, i); + return pir::Int64Attribute::get(ctx, i); } - virtual ir::Attribute operator()(float f) { + virtual pir::Attribute operator()(float f) { VLOG(10) << "translating float"; - return ir::FloatAttribute::get(ctx, f); + return pir::FloatAttribute::get(ctx, f); } - virtual 
ir::Attribute operator()(bool b) { + virtual pir::Attribute operator()(bool b) { VLOG(10) << "translating bool"; - return ir::BoolAttribute::get(ctx, b); + return pir::BoolAttribute::get(ctx, b); } - virtual ir::Attribute operator()(double d) { + virtual pir::Attribute operator()(double d) { VLOG(10) << "translating double"; - return ir::DoubleAttribute::get(ctx, d); + return pir::DoubleAttribute::get(ctx, d); } - virtual ir::Attribute operator()(const std::string& str) { + virtual pir::Attribute operator()(const std::string& str) { VLOG(10) << "translating string"; - return ir::StrAttribute::get(ctx, str); + return pir::StrAttribute::get(ctx, str); } - virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) { + virtual pir::Attribute operator()( + const paddle::experimental::Scalar& scalar) { VLOG(10) << "translating scalar"; IR_THROW("not support translating paddle::experimental::Scalar"); } - virtual ir::Attribute operator()(const std::vector& strs) { + virtual pir::Attribute operator()(const std::vector& strs) { VLOG(10) << "translating vector"; - std::vector attrs; + std::vector attrs; attrs.reserve(strs.size()); for (const auto& v : strs) { - attrs.push_back(ir::StrAttribute::get(ctx, v)); + attrs.push_back(pir::StrAttribute::get(ctx, v)); } - return ir::ArrayAttribute::get(ctx, attrs); + return pir::ArrayAttribute::get(ctx, attrs); } - virtual ir::Attribute operator()(const std::vector& fs) { + virtual pir::Attribute operator()(const std::vector& fs) { VLOG(10) << "translating vector"; - std::vector attrs; + std::vector attrs; attrs.reserve(fs.size()); for (const auto& v : fs) { - attrs.push_back(ir::FloatAttribute::get(ctx, v)); + attrs.push_back(pir::FloatAttribute::get(ctx, v)); } - return ir::ArrayAttribute::get(ctx, attrs); + return pir::ArrayAttribute::get(ctx, attrs); } - virtual ir::Attribute operator()(const std::vector& is) { + virtual pir::Attribute operator()(const std::vector& is) { VLOG(10) << "translating vector"; - std::vector attrs; + std::vector attrs; attrs.reserve(is.size()); for (const auto& v : is) { - attrs.push_back(ir::Int32Attribute::get(ctx, v)); + attrs.push_back(pir::Int32Attribute::get(ctx, v)); } - return ir::ArrayAttribute::get(ctx, attrs); + return pir::ArrayAttribute::get(ctx, attrs); } - virtual ir::Attribute operator()(const std::vector& bs) { + virtual pir::Attribute operator()(const std::vector& bs) { VLOG(10) << "translating vector"; - std::vector attrs; + std::vector attrs; attrs.reserve(bs.size()); for (const auto& v : bs) { - attrs.push_back(ir::BoolAttribute::get(ctx, v)); + attrs.push_back(pir::BoolAttribute::get(ctx, v)); } - return ir::ArrayAttribute::get(ctx, attrs); + return pir::ArrayAttribute::get(ctx, attrs); } - virtual ir::Attribute operator()(const std::vector& i64s) { + virtual pir::Attribute operator()(const std::vector& i64s) { VLOG(10) << "translating vector size: " << i64s.size(); - std::vector attrs; + std::vector attrs; attrs.reserve(i64s.size()); for (const auto& v : i64s) { - attrs.push_back(ir::Int64Attribute::get(ctx, v)); + attrs.push_back(pir::Int64Attribute::get(ctx, v)); } - return ir::ArrayAttribute::get(ctx, attrs); + return pir::ArrayAttribute::get(ctx, attrs); } - virtual ir::Attribute operator()(const std::vector& ds) { + virtual pir::Attribute operator()(const std::vector& ds) { VLOG(10) << "translating vector"; - std::vector attrs; + std::vector attrs; attrs.reserve(ds.size()); for (const auto& v : ds) { - attrs.push_back(ir::DoubleAttribute::get(ctx, v)); + 
attrs.push_back(pir::DoubleAttribute::get(ctx, v)); } - return ir::ArrayAttribute::get(ctx, attrs); + return pir::ArrayAttribute::get(ctx, attrs); } - virtual ir::Attribute operator()( + virtual pir::Attribute operator()( const std::vector& ss) { VLOG(10) << "translating vector"; - std::vector attrs; + std::vector attrs; attrs.reserve(ss.size()); for (const auto& v : ss) { attrs.push_back(dialect::ScalarAttribute::get(ctx, v)); } VLOG(10) << "translating vector Done"; - return ir::ArrayAttribute::get(ctx, attrs); + return pir::ArrayAttribute::get(ctx, attrs); } - virtual ir::Attribute operator()(const paddle::blank& blank) { + virtual pir::Attribute operator()(const paddle::blank& blank) { VLOG(10) << "translating paddle::blank"; - return ir::Attribute(nullptr); + return pir::Attribute(nullptr); } template - ir::Attribute operator()(T attr) { + pir::Attribute operator()(T attr) { VLOG(10) << "translating null type"; - return ir::Attribute(nullptr); + return pir::Attribute(nullptr); } }; @@ -160,19 +161,19 @@ class Int64ArrayAttributeVisitor : public AttributeVisitor { public: using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(const std::vector& is) override { + pir::Attribute operator()(const std::vector& is) override { VLOG(10) << "translating vector"; - std::vector attrs; + std::vector attrs; attrs.reserve(is.size()); for (const auto& v : is) { - attrs.push_back(ir::Int64Attribute::get(ctx, v)); + attrs.push_back(pir::Int64Attribute::get(ctx, v)); } - return ir::ArrayAttribute::get(ctx, attrs); + return pir::ArrayAttribute::get(ctx, attrs); } - ir::Attribute operator()(const paddle::blank& blank) override { + pir::Attribute operator()(const paddle::blank& blank) override { VLOG(10) << "translating paddle::blank to int64[]"; - return ir::ArrayAttribute::get(ctx, {}); + return pir::ArrayAttribute::get(ctx, {}); } }; @@ -180,22 +181,22 @@ class Int64AttributeVisitor : public AttributeVisitor { public: using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(int is) override { + pir::Attribute operator()(int is) override { VLOG(10) << "translating int to Int64Attribute"; - return ir::Int64Attribute::get(ctx, is); + return pir::Int64Attribute::get(ctx, is); } }; class IntArrayAttributeVisitor : public AttributeVisitor { public: using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(const std::vector& is) override { + pir::Attribute operator()(const std::vector& is) override { VLOG(10) << "translating vector to IntArray"; phi::IntArray data(is); return paddle::dialect::IntArrayAttribute::get(ctx, data); } - ir::Attribute operator()(const std::vector& is) override { + pir::Attribute operator()(const std::vector& is) override { VLOG(10) << "translating vector to IntArray"; phi::IntArray data(is); return paddle::dialect::IntArrayAttribute::get(ctx, data); @@ -205,14 +206,14 @@ class IntArrayAttributeVisitor : public AttributeVisitor { class DataTypeAttributeVisitor : public AttributeVisitor { public: using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(int i) override { + pir::Attribute operator()(int i) override { VLOG(10) << "translating int to DataType: " << i; auto phi_dtype = phi::TransToPhiDataType(i); return paddle::dialect::DataTypeAttribute::get(ctx, phi_dtype); } - ir::Attribute operator()(const paddle::blank& blank) override { + pir::Attribute operator()(const paddle::blank& blank) override { VLOG(10) << "translating paddle::blank to DataType::UNDEFINED"; return paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType()); } 
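// ---------------------------------------------------------------------------
// Background sketch (std::variant analogue, not Paddle's types): the visitor
// classes above work because framework::Attribute is a variant, so the whole
// translation is one visit dispatched to the overload matching the active
// alternative; subclasses such as DataTypeAttributeVisitor merely override a
// single overload. Minimal self-contained analogue:
// ---------------------------------------------------------------------------
#include <iostream>
#include <string>
#include <variant>
#include <vector>

using Attr = std::variant<int, float, std::string, std::vector<int>>;

struct Visitor {
  void operator()(int v) const { std::cout << "int " << v << "\n"; }
  void operator()(float v) const { std::cout << "float " << v << "\n"; }
  void operator()(const std::string& s) const { std::cout << "str " << s << "\n"; }
  void operator()(const std::vector<int>& v) const {
    std::cout << "int[" << v.size() << "]\n";
  }
};

int main() {
  Attr a = std::vector<int>{1, 2, 3};
  std::visit(Visitor{}, a);  // prints: int[3]
}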
@@ -222,7 +223,7 @@ class PlaceAttributeVisitor : public AttributeVisitor { public: using AttributeVisitor::AttributeVisitor; - ir::Attribute operator()(const paddle::blank& blank) override { + pir::Attribute operator()(const paddle::blank& blank) override { VLOG(10) << "translating paddle::blank to Place::UNDEFINED"; phi::Place data(phi::AllocationType::UNDEFINED); return paddle::dialect::PlaceAttribute::get(ctx, data); @@ -237,17 +238,17 @@ AttributeTranslator::AttributeTranslator() { new DataTypeAttributeVisitor(); special_visitors["paddle::dialect::PlaceAttribute"] = new PlaceAttributeVisitor(); - special_visitors["ir::ArrayAttribute"] = + special_visitors["pir::ArrayAttribute"] = new Int64ArrayAttributeVisitor(); - special_visitors["ir::Int64Attribute"] = new Int64AttributeVisitor(); + special_visitors["pir::Int64Attribute"] = new Int64AttributeVisitor(); } -ir::Attribute AttributeTranslator::operator()( +pir::Attribute AttributeTranslator::operator()( const framework::Attribute& attr) { return paddle::visit(*general_visitor, attr); } -ir::Attribute AttributeTranslator::operator()( +pir::Attribute AttributeTranslator::operator()( const std::string& target_type, const framework::Attribute& attr) { if (special_visitors.find(target_type) == special_visitors.end()) { VLOG(10) << "[" << target_type << "] not found"; diff --git a/paddle/fluid/ir_adaptor/translator/attribute_translator.h b/paddle/fluid/ir_adaptor/translator/attribute_translator.h index ea509c7e34673..2a716b0ef7d18 100644 --- a/paddle/fluid/ir_adaptor/translator/attribute_translator.h +++ b/paddle/fluid/ir_adaptor/translator/attribute_translator.h @@ -17,9 +17,9 @@ #include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/type_defs.h" -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/ir_context.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/ir_context.h" #pragma once @@ -45,9 +45,9 @@ class AttributeTranslator { return attribute_translator; } - ir::Attribute operator()(const framework::Attribute& attr); - ir::Attribute operator()(const std::string& target_type, - const framework::Attribute& attr); + pir::Attribute operator()(const framework::Attribute& attr); + pir::Attribute operator()(const std::string& target_type, + const framework::Attribute& attr); }; } // namespace translator diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 39a6acdd21b55..b441fe6c87b69 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -23,27 +23,27 @@ #include #include "paddle/fluid/framework/op_desc.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" #include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h" #include "paddle/fluid/ir_adaptor/translator/type_translator.h" #include "paddle/fluid/ir_adaptor/translator/utils.h" -#include "paddle/ir/core/builder.h" -#include "paddle/ir/core/builtin_op.h" -#include 
"paddle/ir/core/builtin_type.h" -#include "paddle/ir/core/enforce.h" -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/value.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in -// paddle/fluid/ir/dialect/CMakeLists.txt. -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" +// paddle/fluid/pir/dialect/CMakeLists.txt. +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" namespace paddle { namespace translator { @@ -56,7 +56,7 @@ using ResultIdx = std::tuple; using OpDesc = paddle::framework::OpDesc; using BlockDesc = paddle::framework::BlockDesc; using VarDesc = paddle::framework::VarDesc; -using OpOutputTypeList = std::vector; +using OpOutputTypeList = std::vector; using OpOutputMapping = std::unordered_map; using OpInputInfo = dialect::OpInputInfo; using OpInputInfoList = std::vector; @@ -64,16 +64,16 @@ using OpAttributeInfo = dialect::OpAttributeInfo; using OpAttributeInfoList = std::vector; using OpOutputInfo = dialect::OpOutputInfo; using OpOutputInfoList = std::vector; -using InputHandlerFn = std::function; -using AttributeHandlerFn = std::function; -constexpr char kTargetDialectPrefix[] = "pd."; // NOLINT -constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT +using InputHandlerFn = std::function; +using AttributeHandlerFn = std::function; +constexpr char kTargetDialectPrefix[] = "pd_op."; // NOLINT +constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT static const std::unordered_set SpecialNonInplaceOps = {}; @@ -126,47 +126,46 @@ inline std::string OpNameCompatibleMapping(std::string op_name) { return op_normalizer[op_name]; } -inline ir::Operation* InsertCombineOperationForTarget( - ir::IrContext* ctx, +inline pir::Operation* InsertCombineOperationForTarget( + pir::IrContext* ctx, TranslationContext* param_map, - ir::Program* program, + pir::Program* program, const std::vector& args) { - std::string combine_op_name(ir::CombineOp::name()); - ir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name); + std::string combine_op_name(pir::CombineOp::name()); + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name); - std::vector src_values; - std::vector types_in_vec; + std::vector src_values; + std::vector types_in_vec; for (const auto& arg_name : args) { auto defining_info = param_map->at(arg_name); src_values.push_back(defining_info.value); types_in_vec.push_back(defining_info.value.type()); } - ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec); - ir::Operation* operation = - ir::Operation::Create(src_values, {}, {target_vec_type}, op_info); + pir::Type target_vec_type = pir::VectorType::get(ctx, types_in_vec); + pir::Operation* operation = + pir::Operation::Create(src_values, {}, {target_vec_type}, op_info); program->block()->push_back(operation); return operation; } -inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, - ir::Program* program, - 
ir::Attribute attr) { +inline pir::Operation* InsertFullOperationForAttributeInput( + pir::IrContext* ctx, pir::Program* program, pir::Attribute attr) { float data = 0.0f; phi::DataType dtype = phi::DataType::UNDEFINED; - if (attr.isa()) { - data = attr.dyn_cast().data(); + if (attr.isa()) { + data = attr.dyn_cast().data(); dtype = phi::DataType::FLOAT32; - } else if (attr.isa()) { - data = static_cast(attr.dyn_cast().data()); + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); dtype = phi::DataType::FLOAT64; - } else if (attr.isa()) { - data = static_cast(attr.dyn_cast().data()); + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); dtype = phi::DataType::INT32; - } else if (attr.isa()) { - data = static_cast(attr.dyn_cast().data()); + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); dtype = phi::DataType::INT64; - } else if (attr.isa()) { - data = static_cast(attr.dyn_cast().data()); + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); dtype = phi::DataType::BOOL; } else if (attr.isa()) { // TODO(phlrain) : need update here, downcast from double to float @@ -174,35 +173,35 @@ inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, attr.dyn_cast().data().to()); dtype = phi::DataType::FLOAT64; } - ir::Builder builder(ctx, program->block()); + pir::Builder builder(ctx, program->block()); dialect::FullOp full_op = builder.Build( std::vector{1}, data, dtype, phi::CPUPlace()); return full_op.operation(); } -inline ir::Operation* InsertFullArrayOperationForAttributeInput( - ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) { +inline pir::Operation* InsertFullArrayOperationForAttributeInput( + pir::IrContext* ctx, pir::Program* program, pir::Attribute attr) { IR_ENFORCE(attr.isa(), "Encounter non IntArray type when trying to insert IntArray " "mutable attribute"); phi::IntArray int_array = attr.dyn_cast().data(); - ir::Builder builder(ctx, program->block()); + pir::Builder builder(ctx, program->block()); dialect::FullIntArrayOp full_int_array_op = builder.Build( int_array.GetData(), phi::DataType::INT64, phi::CPUPlace()); return full_int_array_op.operation(); } -inline ir::Operation* InsertStackOperationForTarget( - ir::IrContext* ctx, +inline pir::Operation* InsertStackOperationForTarget( + pir::IrContext* ctx, TranslationContext* param_map, - ir::Program* program, + pir::Program* program, const std::vector& args, int axis = 0) { auto* combine_op = InsertCombineOperationForTarget(ctx, param_map, program, args); - ir::Builder builder(ctx, program->block()); + pir::Builder builder(ctx, program->block()); dialect::StackOp stack_op = builder.Build(combine_op->result(0), axis); return stack_op.operation(); @@ -210,8 +209,8 @@ inline ir::Operation* InsertStackOperationForTarget( } // namespace -ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx, - const OpDesc& op_desc) { +pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) { std::string target_op_name = kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { @@ -230,11 +229,11 @@ ir::OpInfo OpTranscriber::LoopkUpOpInfo(ir::IrContext* ctx, } void OpTranscriber::InsertSliceOperationForInput( - ir::IrContext* ctx, + pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, const OpInputInfoList& input_infos, - ir::Program* program) { + pir::Program* program) { auto& op_normalizer = OpNameNormalizer::instance(); 
std::set yaml_input_set; for (const auto& info : input_infos) { @@ -265,10 +264,11 @@ void OpTranscriber::InsertSliceOperationForInput( } } -ir::OpResult OpTranscriber::GetAttributeAsInput(ir::IrContext* ctx, - ir::Program* program, - const OpDesc& op_desc, - const OpInputInfo& input_info) { +pir::OpResult OpTranscriber::GetAttributeAsInput( + pir::IrContext* ctx, + pir::Program* program, + const OpDesc& op_desc, + const OpInputInfo& input_info) { auto& attribute_translator = AttributeTranslator::instance(); auto& op_normalizer = OpNameNormalizer::instance(); @@ -283,10 +283,10 @@ ir::OpResult OpTranscriber::GetAttributeAsInput(ir::IrContext* ctx, paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name); VLOG(10) << "[" << op_desc.Type() << "][attribute]" << " name: " << legacy_attr_name << " " << legacy_attr.index(); - ir::Attribute new_attr = + pir::Attribute new_attr = attribute_translator(input_info.type_name, legacy_attr); - ir::Operation* defining_op = nullptr; + pir::Operation* defining_op = nullptr; bool is_int_array = (input_info.type_name.find("IntArrayAttribute") != input_info.type_name.npos); if (is_int_array) { @@ -299,13 +299,13 @@ ir::OpResult OpTranscriber::GetAttributeAsInput(ir::IrContext* ctx, return defining_op->result(0); } -std::vector OpTranscriber::GenerateOperationInput( - ir::IrContext* ctx, +std::vector OpTranscriber::GenerateOperationInput( + pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, const std::string& normalized_op_name, const OpInputInfoList& input_infos, - ir::Program* program) { + pir::Program* program) { VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance"; auto& op_normalizer = OpNameNormalizer::instance(); @@ -314,11 +314,11 @@ std::vector OpTranscriber::GenerateOperationInput( VLOG(10) << "[op:" << op_desc.Type() << "][input] start"; - std::vector op_inputs; + std::vector op_inputs; for (const auto& info : input_infos) { if (auto special_handler = this->GetSpecialInputHandlers(info.name)) { - ir::OpResult ret = special_handler( + pir::OpResult ret = special_handler( ctx, param_map, op_desc, normalized_op_name, info, program); op_inputs.push_back(ret); continue; @@ -407,7 +407,7 @@ std::vector OpTranscriber::GenerateOperationInput( } std::tuple -OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx, +OpTranscriber::GenerateOperationOutput(pir::IrContext* ctx, const OpDesc& op_desc, const OpOutputInfoList& output_infos) { OpOutputMapping arg_to_idx; @@ -457,7 +457,7 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx, legacy_output_vars[0]); if (var->GetType() == paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) { - ir::Type translated_var_type = + pir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); op_output_types.push_back(translated_var_type); arg_to_idx[var->Name()] = {cur_output_idx, 0}; @@ -486,7 +486,8 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx, << "[" << op_desc.Type() << "]" << info.name << " var: " << var_name << " type: " << var->GetType(); - ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); + pir::Type translated_var_type = + type_translator[var->GetType()](ctx, *var); arg_to_idx[var_name] = {cur_output_idx, 0}; op_output_types.push_back(translated_var_type); @@ -496,7 +497,7 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx, VLOG(10) << "[output translating]" << "[" << op_desc.Type() << "]" << info.name << " :" << info.type_name << " var: " << legacy_output_name; - std::vector types; + 
std::vector types; for (IdxInVector idx_in_vec = 0; idx_in_vec < legacy_output_vars.size(); idx_in_vec++) { const auto& var_name = legacy_output_vars[idx_in_vec]; @@ -509,26 +510,26 @@ OpTranscriber::GenerateOperationOutput(ir::IrContext* ctx, VLOG(10) << "[output translating]" << "[" << op_desc.Type() << "]" << info.name << " var: " << var_name << " type: " << var->GetType(); - ir::Type translated_var_type = + pir::Type translated_var_type = type_translator[var->GetType()](ctx, *var); types.push_back(translated_var_type); arg_to_idx[var_name] = {cur_output_idx, idx_in_vec}; } - ir::Type vec_type = ir::VectorType::get(ctx, types); + pir::Type vec_type = pir::VectorType::get(ctx, types); op_output_types.push_back(vec_type); } } return {op_output_types, arg_to_idx}; } -ir::AttributeMap OpTranscriber::TranslateOpAttribute( - ir::IrContext* ctx, +pir::AttributeMap OpTranscriber::TranslateOpAttribute( + pir::IrContext* ctx, const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc) { auto& attribute_translator = AttributeTranslator::instance(); auto& op_normalizer = OpNameNormalizer::instance(); - ir::AttributeMap attribute_map = {}; + pir::AttributeMap attribute_map = {}; for (const auto& info : op_attr_infos) { if (auto handler = this->GetSpecialAttributeHandlers(info.name)) { @@ -546,7 +547,7 @@ ir::AttributeMap OpTranscriber::TranslateOpAttribute( op_desc.GetAttr(legacy_attr_name); VLOG(10) << "attribute in " << op_desc.Type() << " name: " << legacy_attr_name << " " << legacy_attr.index(); - ir::Attribute new_attr = + pir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr); attribute_map[info.name] = new_attr; if (!new_attr) { @@ -563,36 +564,36 @@ ir::AttributeMap OpTranscriber::TranslateOpAttribute( return attribute_map; } -void OpTranscriber::HandleNonexistentAttribute(ir::IrContext*, - ir::AttributeMap* attribute_map, +void OpTranscriber::HandleNonexistentAttribute(pir::IrContext*, + pir::AttributeMap* attribute_map, const OpAttributeInfo& info) { auto& attribute_translator = AttributeTranslator::instance(); (*attribute_map)[info.name] = attribute_translator(info.type_name, paddle::framework::Attribute()); } -void OpTranscriber::RecordOpResultMapping(ir::IrContext* ctx, +void OpTranscriber::RecordOpResultMapping(pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, - ir::Operation* operation, + pir::Operation* operation, const OpOutputMapping& arg_to_idx) { for (const auto& [arg_name, idx] : arg_to_idx) { const auto& [idx_in_op, idx_in_vec] = idx; VLOG(10) << "[output recording]" << "[" << op_desc.Type() << "]" << arg_name << " " << idx_in_op << " " << idx_in_vec; - ir::OpResult value = operation->result(idx_in_op); - bool generated_by_vector = value.type().isa(); + pir::OpResult value = operation->result(idx_in_op); + bool generated_by_vector = value.type().isa(); (*param_map)[arg_name] = VariableDefiningInfo( value, generated_by_vector, generated_by_vector ? 
idx_in_vec : -1); } } -ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx, - TranslationContext* param_map, - const OpDesc& op_desc, - ir::Program* program) { +pir::Operation* OpTranscriber::operator()(pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + pir::Program* program) { auto op_info = this->LoopkUpOpInfo(ctx, op_desc); auto* op_info_concept = op_info.GetInterfaceImpl(); @@ -618,8 +619,8 @@ ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx, this->TranslateOpAttribute(ctx, op_info.name(), attr_infos, op_desc); VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end."; - ir::Operation* operation = - ir::Operation::Create(op_inputs, attribute_map, op_output_types, op_info); + pir::Operation* operation = pir::Operation::Create( + op_inputs, attribute_map, op_output_types, op_info); VLOG(4) << "[general op][" << op_desc.Type() << "] opearation creation end."; program->block()->push_back(operation); @@ -630,13 +631,13 @@ ir::Operation* OpTranscriber::operator()(ir::IrContext* ctx, } struct CastOpTranscriber : public OpTranscriber { - ir::AttributeMap TranslateOpAttribute( - ir::IrContext*, + pir::AttributeMap TranslateOpAttribute( + pir::IrContext*, const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc) override { auto& attribute_translator = AttributeTranslator::instance(); - ir::AttributeMap attribute_map = {}; + pir::AttributeMap attribute_map = {}; const OpAttributeInfo info = op_attr_infos[0]; std::string legacy_attr_name("out_dtype"); @@ -647,7 +648,7 @@ struct CastOpTranscriber : public OpTranscriber { } VLOG(10) << "attribute in " << op_desc.Type() << " name: " << legacy_attr_name << " " << legacy_attr.index(); - ir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr); + pir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr); attribute_map[info.name] = new_attr; return attribute_map; @@ -655,35 +656,35 @@ struct CastOpTranscriber : public OpTranscriber { }; struct EmbeddingOpTranscriber : public OpTranscriber { - void HandleNonexistentAttribute(ir::IrContext* ctx, - ir::AttributeMap* attribute_map, + void HandleNonexistentAttribute(pir::IrContext* ctx, + pir::AttributeMap* attribute_map, const OpAttributeInfo& info) override { if (info.name == "padding_idx") { - (*attribute_map)[info.name] = ir::Int64Attribute::get(ctx, -1); + (*attribute_map)[info.name] = pir::Int64Attribute::get(ctx, -1); } else if (info.name == "sparse") { - (*attribute_map)[info.name] = ir::BoolAttribute::get(ctx, false); + (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false); } } }; struct IncrementOpTranscriber : public OpTranscriber { - ir::AttributeMap TranslateOpAttribute( - ir::IrContext* ctx, + pir::AttributeMap TranslateOpAttribute( + pir::IrContext* ctx, const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc) override { auto& attribute_translator = AttributeTranslator::instance(); - ir::AttributeMap attribute_map = {}; + pir::AttributeMap attribute_map = {}; paddle::framework::Attribute legacy_attr; if (op_desc.HasAttr("step")) { legacy_attr = op_desc.GetAttr("step"); VLOG(10) << "attribute in " << op_desc.Type() << " step: " << " " << legacy_attr.index(); - ir::Attribute new_attr = attribute_translator(legacy_attr); + pir::Attribute new_attr = attribute_translator(legacy_attr); attribute_map["value"] = new_attr; } else { - attribute_map["value"] = ir::FloatAttribute::get(ctx, 1.0f); + 
attribute_map["value"] = pir::FloatAttribute::get(ctx, 1.0f); } return attribute_map; @@ -694,21 +695,23 @@ struct IncrementOpTranscriber : public OpTranscriber { // `legacy_ops.yaml`. For this op we simulate the logic in // python/paddle/tensor/creation.py::assign(x, output) struct AssignValueOpTranscriber : public OpTranscriber { - ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override { - std::string target_op_name = "pd.assign_value"; + pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) override { + std::string target_op_name = "pd_op.assign_value"; const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { IR_THROW( - "Op assign_value should have corresponding OpInfo pd.assign_value"); + "Op assign_value should have corresponding OpInfo " + "pd_op.assign_value"); } return op_info; } - ir::Operation* operator()(ir::IrContext* ctx, - TranslationContext* param_map, - const OpDesc& op_desc, - ir::Program* program) override { + pir::Operation* operator()(pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + pir::Program* program) override { VLOG(10) << "[op assign_value] start transcribing"; auto op_info = this->LoopkUpOpInfo(ctx, op_desc); auto* op_info_concept = @@ -724,7 +727,7 @@ struct AssignValueOpTranscriber : public OpTranscriber { } auto& attribute_translator = AttributeTranslator::instance(); - ir::AttributeMap attribute_map; + pir::AttributeMap attribute_map; paddle::framework::Attribute legacy_attr; if (op_desc.HasAttr("shape")) { @@ -732,7 +735,7 @@ struct AssignValueOpTranscriber : public OpTranscriber { } else { IR_THROW("Op assign_value should have attribute `shape` but not find"); } - ir::Attribute attr_shape = + pir::Attribute attr_shape = attribute_translator(attr_info_maps.at("shape").type_name, legacy_attr); attribute_map["shape"] = attr_shape; @@ -741,11 +744,11 @@ struct AssignValueOpTranscriber : public OpTranscriber { } else { IR_THROW("Op assign_value should have attribute `dtype` but not find"); } - ir::Attribute attr_dtype = + pir::Attribute attr_dtype = attribute_translator(attr_info_maps.at("dtype").type_name, legacy_attr); attribute_map["dtype"] = attr_dtype; - ir::Attribute attr_place = + pir::Attribute attr_place = dialect::PlaceAttribute::get(ctx, phi::CPUPlace()); attribute_map["place"] = attr_place; @@ -764,20 +767,20 @@ struct AssignValueOpTranscriber : public OpTranscriber { "Op assign_value should have attribute `**_values` but not find"); } - ir::Attribute attr_values = attribute_translator( + pir::Attribute attr_values = attribute_translator( attr_info_maps.at("values").type_name, legacy_attr); attribute_map["values"] = attr_values; VLOG(10) << "[op assign_value] attribute translation done"; - std::vector op_inputs = {}; + std::vector op_inputs = {}; OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; std::tie(op_output_types, arg_to_idx) = this->GenerateOperationOutput(ctx, op_desc, output_infos); - ir::Operation* operation = ir::Operation::Create( + pir::Operation* operation = pir::Operation::Create( op_inputs, attribute_map, op_output_types, op_info); program->block()->push_back(operation); RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx); @@ -792,12 +795,12 @@ struct AssignValueOpTranscriber : public OpTranscriber { // So we generate an input by `full` with same type of output `DropoutState` of // OpDesc And we still should be aware that `DropoutState` is an optional output // in static graph. 
-ir::OpResult TranslateDropOutStateIn(ir::IrContext* ctx, - TranslationContext* param_map, - const OpDesc& op_desc, - const std::string& normalized_op_name, - const OpInputInfo& input_info, - ir::Program* program) { +pir::OpResult TranslateDropOutStateIn(pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + const std::string& normalized_op_name, + const OpInputInfo& input_info, + pir::Program* program) { const std::string legacy_output_name = "DropoutState"; std::vector legacy_output_vars; if (op_desc.HasOutput(legacy_output_name)) { @@ -806,7 +809,7 @@ ir::OpResult TranslateDropOutStateIn(ir::IrContext* ctx, if (legacy_output_vars.empty()) { VLOG(3) << "[input translating] not find output variable: DropoutState"; - return ir::OpResult(nullptr); + return pir::OpResult(nullptr); } // `DropoutState` is a tensor @@ -816,14 +819,14 @@ ir::OpResult TranslateDropOutStateIn(ir::IrContext* ctx, IR_THROW("Unexpected: Rnn Op should have a non-empty DropoutState"); } auto& type_translator = TypeTranslator::instance(); - ir::Type translated_var_type = + pir::Type translated_var_type = type_translator[dropout_state->GetType()](ctx, *dropout_state); IR_ENFORCE( translated_var_type.isa(), "Unexpected: Rnn Op's output DropoutState should be a DenseTensor"); auto tensor_type = translated_var_type.dyn_cast(); - ir::Builder builder(ctx, program->block()); + pir::Builder builder(ctx, program->block()); dialect::FullOp full_op = builder.Build( phi::vectorize(tensor_type.dims()), 0.0f, @@ -845,26 +848,27 @@ struct RnnOpTranscriber : public OpTranscriber { }; struct EmbeddingGradOpTranscriber : public OpTranscriber { - void HandleNonexistentAttribute(ir::IrContext* ctx, - ir::AttributeMap* attribute_map, + void HandleNonexistentAttribute(pir::IrContext* ctx, + pir::AttributeMap* attribute_map, const OpAttributeInfo& info) override { if (info.name == "padding_idx") { - (*attribute_map)[info.name] = ir::Int64Attribute::get(ctx, -1); + (*attribute_map)[info.name] = pir::Int64Attribute::get(ctx, -1); } else if (info.name == "sparse") { - (*attribute_map)[info.name] = ir::BoolAttribute::get(ctx, false); + (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false); } } - ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override { + pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) override { std::string target_op_name = kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type()); bool is_sparse = paddle::get(op_desc.GetAttr("is_sparse")); if (is_sparse) { - target_op_name = "pd.embedding_grad_sparse"; + target_op_name = "pd_op.embedding_grad_sparse"; } else { - target_op_name = "pd.embedding_grad_dense"; + target_op_name = "pd_op.embedding_grad_dense"; } VLOG(6) << "[op name normalizing: " << op_desc.Type() << " to " << target_op_name; @@ -880,45 +884,45 @@ struct EmbeddingGradOpTranscriber : public OpTranscriber { }; struct FeedOpTranscriber : public OpTranscriber { - ir::AttributeMap TranslateOpAttribute( - ir::IrContext* ctx, + pir::AttributeMap TranslateOpAttribute( + pir::IrContext* ctx, const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc) override { - ir::AttributeMap attribute_map = { - {"name", ir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])}, + pir::AttributeMap attribute_map = { + {"name", pir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0])}, {"col", - ir::Int32Attribute::get(ctx, op_desc.GetAttrIfExists("col"))}, + pir::Int32Attribute::get(ctx, 
op_desc.GetAttrIfExists("col"))}, }; return attribute_map; } - std::vector GenerateOperationInput( - ir::IrContext* ctx, + std::vector GenerateOperationInput( + pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, const std::string& normalized_op_name, const OpInputInfoList& input_infos, - ir::Program* program) override { + pir::Program* program) override { return {}; } }; struct DataOpTranscriber : public FeedOpTranscriber { - ir::AttributeMap TranslateOpAttribute( - ir::IrContext* ctx, + pir::AttributeMap TranslateOpAttribute( + pir::IrContext* ctx, const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc) override { int allocate_type = paddle::get(op_desc.GetAttr("place")); auto& attribute_translator = AttributeTranslator::instance(); - ir::Attribute shape = attribute_translator( + pir::Attribute shape = attribute_translator( "paddle::dialect::IntArrayAttribute", op_desc.GetAttr("shape")); - ir::AttributeMap attribute_map = { + pir::AttributeMap attribute_map = { {"name", - ir::StrAttribute::get(ctx, - op_desc.GetAttrIfExists("name"))}, + pir::StrAttribute::get(ctx, + op_desc.GetAttrIfExists("name"))}, {"shape", shape}, {"dtype", paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32)}, @@ -932,18 +936,18 @@ struct DataOpTranscriber : public FeedOpTranscriber { }; struct SplitOpTranscriber : public OpTranscriber { - std::vector GenerateOperationInput( - ir::IrContext* ctx, + std::vector GenerateOperationInput( + pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, const std::string& normalized_op_name, const OpInputInfoList& input_infos, - ir::Program* program) override { + pir::Program* program) override { // input of split is [Tensor x, IntArray sections, Scalar(int) axis)] VLOG(10) << "[op:split][input] start"; - std::vector op_inputs; + std::vector op_inputs; // process first input auto x_input_vars = op_desc.Input("X"); IR_ENFORCE(x_input_vars.size() == 1, "x input of split MUST be a tensor"); @@ -963,7 +967,7 @@ struct SplitOpTranscriber : public OpTranscriber { op_inputs.push_back(combine_op->result(0)); } else { auto& attribute_translator = AttributeTranslator::instance(); - ir::Attribute new_attr = attribute_translator( + pir::Attribute new_attr = attribute_translator( "paddle::dialect::IntArrayAttribute", op_desc.GetAttr("sections")); auto sec_defin_op = InsertFullArrayOperationForAttributeInput(ctx, program, new_attr); @@ -982,8 +986,8 @@ struct SplitOpTranscriber : public OpTranscriber { op_inputs.push_back(axis_defining_info.value); } else { auto& attribute_translator = AttributeTranslator::instance(); - ir::Attribute new_attr = - attribute_translator("ir::Int32Attribute", op_desc.GetAttr("axis")); + pir::Attribute new_attr = + attribute_translator("pir::Int32Attribute", op_desc.GetAttr("axis")); auto sec_defin_op = InsertFullOperationForAttributeInput(ctx, program, new_attr); @@ -993,16 +997,16 @@ struct SplitOpTranscriber : public OpTranscriber { return op_inputs; } - ir::AttributeMap TranslateOpAttribute( - ir::IrContext* ctx, + pir::AttributeMap TranslateOpAttribute( + pir::IrContext* ctx, const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc) override { int num = paddle::get(op_desc.GetAttr("num")); if (num > 0) { - ir::AttributeMap attribute_map = { + pir::AttributeMap attribute_map = { {"num", - ir::Int32Attribute::get(ctx, op_desc.GetAttrIfExists("num"))}, + pir::Int32Attribute::get(ctx, 
             op_desc.GetAttrIfExists<int>("num"))},
     };
     return attribute_map;
@@ -1011,19 +1015,20 @@ struct SplitOpTranscriber : public OpTranscriber {
     return {};
   }

-  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
+  pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
+                            const OpDesc& op_desc) override {
     int num = paddle::get<int>(op_desc.GetAttr("num"));
     std::string target_op_name;
     if (num > 0) {
-      target_op_name = "pd.split_with_num";
+      target_op_name = "pd_op.split_with_num";
     } else {
-      target_op_name = "pd.split";
+      target_op_name = "pd_op.split";
     }
     const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
     if (!op_info) {
-      IR_THROW("Op assign_value should have corresponding OpInfo pd.split");
+      IR_THROW("Op split should have corresponding OpInfo pd_op.split");
     }

     return op_info;
@@ -1031,10 +1036,10 @@ struct SplitOpTranscriber : public OpTranscriber {
 };

 struct FetchOpTranscriber : public OpTranscriber {
-  ir::Operation* operator()(ir::IrContext* ctx,
-                            TranslationContext* param_map,
-                            const OpDesc& op_desc,
-                            ir::Program* program) override {
+  pir::Operation* operator()(pir::IrContext* ctx,
+                             TranslationContext* param_map,
+                             const OpDesc& op_desc,
+                             pir::Program* program) override {
     auto op_info = this->LoopkUpOpInfo(ctx, op_desc);

     auto* op_info_concept =
@@ -1052,14 +1057,14 @@ struct FetchOpTranscriber : public OpTranscriber {
         ctx, param_map, op_desc, op_info.name(), input_infos, program);

     OpOutputTypeList op_output_types;
-    ir::AttributeMap attribute_map = {
-        {"name", ir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])},
+    pir::AttributeMap attribute_map = {
+        {"name", pir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])},
         {"col",
-         ir::Int32Attribute::get(ctx, op_desc.GetAttrIfExists<int>("col"))},
+         pir::Int32Attribute::get(ctx, op_desc.GetAttrIfExists<int>("col"))},
     };

     op_output_types.push_back(op_inputs[0].type());
-    ir::Operation* operation = ir::Operation::Create(
+    pir::Operation* operation = pir::Operation::Create(
         op_inputs, attribute_map, op_output_types, op_info);
     program->block()->push_back(operation);

@@ -1068,13 +1073,13 @@ struct FetchOpTranscriber : public OpTranscriber {
 };

 struct ShadowOutputOpTranscriber : public OpTranscriber {
-  ir::Operation* operator()(ir::IrContext* ctx,
-                            TranslationContext* param_map,
-                            const OpDesc& op_desc,
-                            ir::Program* program) override {
-    auto op_info = ctx->GetRegisteredOpInfo(ir::SetParameterOp::name());
+  pir::Operation* operator()(pir::IrContext* ctx,
+                             TranslationContext* param_map,
+                             const OpDesc& op_desc,
+                             pir::Program* program) override {
+    auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name());

-    std::vector<ir::OpResult> op_inputs;
+    std::vector<pir::OpResult> op_inputs;
     auto legacy_input_vars = op_desc.Input("x", true);

     auto defining_info = (*param_map)[legacy_input_vars[0]];
@@ -1086,14 +1091,14 @@ struct ShadowOutputOpTranscriber : public OpTranscriber {

     op_inputs.push_back(defining_info.value);

-    ir::AttributeMap attribute_map = {
+    pir::AttributeMap attribute_map = {
         {"parameter_name",
-         ir::StrAttribute::get(ctx,
-                               op_desc.GetAttrIfExists<std::string>("name"))},
+         pir::StrAttribute::get(ctx,
+                                op_desc.GetAttrIfExists<std::string>("name"))},
     };

-    ir::Operation* operation =
-        ir::Operation::Create(op_inputs, attribute_map, {}, op_info);
+    pir::Operation* operation =
+        pir::Operation::Create(op_inputs, attribute_map, {}, op_info);
     program->block()->push_back(operation);

     return operation;
@@ -1102,7 +1107,8 @@ struct ShadowOutputOpTranscriber : public OpTranscriber {

 // NOTE: add_n op in legacy ops doesn't have a kernel, so we use a new op
 // for now
 struct AddNOpTranscriber : public OpTranscriber {
-  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
+  pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
+                            const OpDesc& op_desc) override {
     std::string target_op_name =
         kTargetDialectPrefix + OpNameCompatibleMapping(op_desc.Type());
     if (IsInplace(op_desc)) {
@@ -1120,18 +1126,20 @@ struct AddNOpTranscriber : public OpTranscriber {
 };

 struct TrilAndTriuOpTranscriber : public OpTranscriber {
-  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
+  pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
+                            const OpDesc& op_desc) override {
     bool lower = PADDLE_GET_CONST(bool, op_desc.GetAttr("lower"));
     std::string target_op_name = "";
     if (lower) {
-      target_op_name = "pd.tril";
+      target_op_name = "pd_op.tril";
     } else {
-      target_op_name = "pd.triu";
+      target_op_name = "pd_op.triu";
     }
     const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
     if (!op_info) {
       IR_THROW(
-          "Op tril_triu should have corresponding OpInfo pd.tril or pd.triu.");
+          "Op tril_triu should have corresponding OpInfo pd_op.tril or "
+          "pd_op.triu.");
     }

     return op_info;
@@ -1139,27 +1147,28 @@ struct TrilAndTriuOpTranscriber : public OpTranscriber {
 };

 struct FillConstant2FullTranscriber : public OpTranscriber {
-  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
+  pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
+                            const OpDesc& op_desc) override {
     const auto& op_info = ctx->GetRegisteredOpInfo(dialect::FullOp::name());
     if (!op_info) {
-      IR_THROW("Op fill_constant should have corresponding OpInfo pd.full");
+      IR_THROW("Op fill_constant should have corresponding OpInfo pd_op.full");
     }

     return op_info;
   }

-  std::vector<ir::OpResult> GenerateOperationInput(
-      ir::IrContext* ctx,
+  std::vector<pir::OpResult> GenerateOperationInput(
+      pir::IrContext* ctx,
       TranslationContext* param_map,
       const OpDesc& op_desc,
       const std::string& normalized_op_name,
       const OpInputInfoList& input_infos,
-      ir::Program* program) override {
+      pir::Program* program) override {
     return {};
   }

-  ir::AttributeMap TranslateOpAttribute(
-      ir::IrContext* ctx,
+  pir::AttributeMap TranslateOpAttribute(
+      pir::IrContext* ctx,
       const std::string& normalized_op_name,
       const OpAttributeInfoList& op_attr_infos,
       const OpDesc& op_desc) override {
@@ -1168,9 +1177,9 @@ struct FillConstant2FullTranscriber : public OpTranscriber {
     float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
     int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));

-    auto attr_value = ir::FloatAttribute::get(ctx, value);
+    auto attr_value = pir::FloatAttribute::get(ctx, value);

-    ir::AttributeMap attribute_map = {
+    pir::AttributeMap attribute_map = {
         {"shape",
          attribute_translator("paddle::dialect::IntArrayAttribute",
                               shape_attr)},
@@ -1181,14 +1190,6 @@ struct FillConstant2FullTranscriber : public OpTranscriber {
              paddle::dialect::VarTypeToDataType(
                  static_cast<paddle::framework::proto::VarType_Type>(dtype)))}};

-    if (op_desc.HasAttr("force_cpu")) {
-      bool force_cpu = PADDLE_GET_CONST(bool, op_desc.GetAttr("force_cpu"));
-      if (force_cpu) {
-        attribute_map["place"] =
-            paddle::dialect::PlaceAttribute::get(ctx, phi::CPUPlace());
-      }
-    }
-
     int place_type = PADDLE_GET_CONST(int, op_desc.GetAttr("place_type"));
     switch (place_type) {
       case -1:
@@ -1212,30 +1213,40 @@
             paddle::dialect::PlaceAttribute::get(ctx, phi::XPUPlace());
         break;
     }
+
+    if (op_desc.HasAttr("force_cpu")) {
+      bool force_cpu = PADDLE_GET_CONST(bool, op_desc.GetAttr("force_cpu"));
+      if (force_cpu) {
attribute_map["place"] = + paddle::dialect::PlaceAttribute::get(ctx, phi::CPUPlace()); + } + } + return attribute_map; } }; struct FillConstant2FullWithTensorTranscriber : public OpTranscriber { - ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override { - const auto& op_info = ctx->GetRegisteredOpInfo("pd.full_with_tensor"); + pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) override { + const auto& op_info = ctx->GetRegisteredOpInfo("pd_op.full_with_tensor"); if (!op_info) { IR_THROW( "Op fill_constant should have corresponding OpInfo " - "pd.full_with_tensor"); + "pd_op.full_with_tensor"); } return op_info; } - std::vector GenerateOperationInput( - ir::IrContext* ctx, + std::vector GenerateOperationInput( + pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, const std::string& normalized_op_name, const OpInputInfoList& input_infos, - ir::Program* program) override { - std::vector op_inputs; + pir::Program* program) override { + std::vector op_inputs; if (op_desc.HasInput("ShapeTensor", true) && op_desc.Input("ShapeTensor", true).size() > 0) { auto shape_tensor_vars = op_desc.Input("ShapeTensor", true); @@ -1250,7 +1261,7 @@ struct FillConstant2FullWithTensorTranscriber : public OpTranscriber { } else { auto& attribute_translator = AttributeTranslator::instance(); paddle::framework::Attribute shape_attr = op_desc.GetAttr("shape"); - ir::Attribute new_attr = attribute_translator( + pir::Attribute new_attr = attribute_translator( "paddle::dialect::IntArrayAttribute", shape_attr); auto defining_op = InsertFullArrayOperationForAttributeInput(ctx, program, new_attr); @@ -1264,7 +1275,7 @@ struct FillConstant2FullWithTensorTranscriber : public OpTranscriber { op_inputs.push_back(defining_info.value); } else { float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value")); - ir::Attribute new_attr = ir::FloatAttribute::get(ctx, value); + pir::Attribute new_attr = pir::FloatAttribute::get(ctx, value); auto defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr); op_inputs.push_back(defining_op->result(0)); @@ -1272,14 +1283,14 @@ struct FillConstant2FullWithTensorTranscriber : public OpTranscriber { return op_inputs; } - ir::AttributeMap TranslateOpAttribute( - ir::IrContext* ctx, + pir::AttributeMap TranslateOpAttribute( + pir::IrContext* ctx, const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc) override { int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); - ir::AttributeMap attribute_map = { + pir::AttributeMap attribute_map = { {"dtype", paddle::dialect::DataTypeAttribute::get( ctx, @@ -1290,10 +1301,10 @@ struct FillConstant2FullWithTensorTranscriber : public OpTranscriber { }; struct FillConstantTranscriber : public OpTranscriber { - ir::Operation* operator()(ir::IrContext* ctx, - TranslationContext* param_map, - const OpDesc& op_desc, - ir::Program* program) override { + pir::Operation* operator()(pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + pir::Program* program) override { bool has_mutable_attribute = op_desc.HasInput("ShapeTensor", true) && op_desc.Input("ShapeTensor", true).size() > 0; has_mutable_attribute |= op_desc.HasInput("ShapeTensorList", true) && @@ -1310,12 +1321,13 @@ struct FillConstantTranscriber : public OpTranscriber { } }; -ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx, - TranslationContext* param_map, - const OpDesc& op_desc, - const std::string& normalized_op_name, - const 
OpInputInfo& input_info, - ir::Program* program) { +pir::OpResult TranslateNumClassesForOneHot( + pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + const std::string& normalized_op_name, + const OpInputInfo& input_info, + pir::Program* program) { const std::string legacy_attr_name = "depth"; const std::string legacy_tensor_name = "depth_tensor"; std::vector legacy_vars; @@ -1343,9 +1355,9 @@ ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx, paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name); VLOG(10) << "[" << op_desc.Type() << "][attribute]" << " name: " << legacy_attr_name << " " << legacy_attr.index(); - ir::Attribute new_attr = attribute_translator(legacy_attr); + pir::Attribute new_attr = attribute_translator(legacy_attr); - ir::Operation* defining_op = + pir::Operation* defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr); return defining_op->result(0); } @@ -1360,16 +1372,16 @@ struct OneHotTranscriber : public OpTranscriber { }; }; -ir::Attribute TranslateReduceAll(ir::IrContext* ctx, - const OpDesc& op_desc, - const OpAttributeInfo& attr_info) { +pir::Attribute TranslateReduceAll(pir::IrContext* ctx, + const OpDesc& op_desc, + const OpAttributeInfo& attr_info) { bool reduce_all = false; if (op_desc.HasAttr("reduce_all")) { reduce_all = paddle::get(op_desc.GetAttr("reduce_all")); } if (reduce_all) { - return ir::ArrayAttribute::get(ctx, std::vector{}); + return pir::ArrayAttribute::get(ctx, std::vector{}); } auto& attribute_translator = AttributeTranslator::instance(); @@ -1391,13 +1403,13 @@ struct ReduceOpTranscriber : public OpTranscriber { }; struct ElementwiseTranscriber : public OpTranscriber { - std::vector GenerateOperationInput( - ir::IrContext* ctx, + std::vector GenerateOperationInput( + pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, const std::string& normalized_op_name, const OpInputInfoList& input_infos, - ir::Program* program) override { + pir::Program* program) override { int axis = paddle::get(op_desc.GetAttr("axis")); if (axis == -1) { @@ -1421,12 +1433,12 @@ struct ElementwiseTranscriber : public OpTranscriber { ctx, param_map, program, x_defining_info, x_name); x_defining_info = param_map->at(x_name); } - ir::OpResult x_value = x_defining_info.value; + pir::OpResult x_value = x_defining_info.value; IR_ENFORCE(x_value, "Expected op[%s]'s input %s is not null", op_desc.Type(), x_name); - ir::Type x_type = x_value.type(); + pir::Type x_type = x_value.type(); IR_ENFORCE(x_type.isa(), "Expected op[%s]'s input %s is DenseTensor but got %s", op_desc.Type(), @@ -1452,12 +1464,12 @@ struct ElementwiseTranscriber : public OpTranscriber { ctx, param_map, program, y_defining_info, y_name); y_defining_info = param_map->at(y_name); } - ir::OpResult y_value = y_defining_info.value; + pir::OpResult y_value = y_defining_info.value; IR_ENFORCE(y_value, "Expected op[%s]'s input %s is not null", op_desc.Type(), y_name); - ir::Type y_type = y_value.type(); + pir::Type y_type = y_value.type(); IR_ENFORCE(y_type.isa(), "Expected op[%s]'s input %s is DenseTensor but got %s", op_desc.Type(), @@ -1482,8 +1494,8 @@ struct ElementwiseTranscriber : public OpTranscriber { axis, append_size); - ir::Builder builder(ctx, program->block()); - ir::OpResult y_new; + pir::Builder builder(ctx, program->block()); + pir::OpResult y_new; if (std::find(y_shape.begin(), y_shape.end(), -1) == y_shape.end()) { std::vector y_new_shape(y_shape); for (int i = 0; i <= append_size; i++) { @@ 
@@ -1500,8 +1512,8 @@ struct ElementwiseTranscriber : public OpTranscriber {
           std::vector<int64_t>(append_size, 1),
           phi::DataType::INT64,
           phi::CPUPlace());
-      auto y_true_shape_op = builder.Build<ir::CombineOp>(
-          std::vector<ir::OpResult>{shape_op.out(), append_shape_op.out()});
+      auto y_true_shape_op = builder.Build<pir::CombineOp>(
+          std::vector<pir::OpResult>{shape_op.out(), append_shape_op.out()});
       auto concat_op =
           builder.Build<dialect::ConcatOp>(y_true_shape_op.out(), 0);
       auto y_new_shape = concat_op.out();
@@ -1513,12 +1525,14 @@ struct ElementwiseTranscriber : public OpTranscriber {
 };

 struct GradAddOpTranscriber : public ElementwiseTranscriber {
-  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
-    const std::string& target_op_name = "pd.add";
+  pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
+                            const OpDesc& op_desc) override {
+    const std::string& target_op_name = "pd_op.add";
     const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
     if (!op_info) {
       IR_THROW(
-          "Op assign_value should have corresponding OpInfo pd.assign_value_");
+          "Op grad_add should have corresponding OpInfo pd_op.add");
     }

     return op_info;
@@ -1526,10 +1540,10 @@ struct GradAddOpTranscriber : public ElementwiseTranscriber {
 };

 struct ElementwiseGradTranscriber : public OpTranscriber {
-  void RecordOpResultMapping(ir::IrContext* ctx,
+  void RecordOpResultMapping(pir::IrContext* ctx,
                              TranslationContext* param_map,
                              const OpDesc& op_desc,
-                             ir::Operation* operation,
+                             pir::Operation* operation,
                              const OpOutputMapping& arg_to_idx) override {
     OpTranscriber::RecordOpResultMapping(
         ctx, param_map, op_desc, operation, arg_to_idx);
@@ -1566,12 +1580,12 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
                op_desc.Type(),
                y_name);
     auto y_defining_info = param_map->at(y_name);
-    ir::OpResult y_value = y_defining_info.value;
+    pir::OpResult y_value = y_defining_info.value;
     IR_ENFORCE(y_value,
                "Expected op[%s]'s input %s is not null",
                op_desc.Type(),
                y_name);
-    ir::Type y_type = y_value.type();
+    pir::Type y_type = y_value.type();
     IR_ENFORCE(y_type.isa<dialect::DenseTensorType>(),
                "Expected op[%s]'s input %s is DenseTensor but got %s",
                op_desc.Type(),
@@ -1581,8 +1595,8 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
         y_type.dyn_cast<dialect::DenseTensorType>();
     std::vector<int64_t> y_shape = phi::vectorize(y_tensor_type.dims());

-    ir::OpResult value = operation->result(idx_in_op);
-    ir::Builder builder(ctx, operation->GetParent());
+    pir::OpResult value = operation->result(idx_in_op);
+    pir::Builder builder(ctx, operation->GetParent());
     auto reshape_op = builder.Build<dialect::ReshapeOp>(value, y_shape);
     (*param_map)[y_grad_var_name] =
         VariableDefiningInfo(reshape_op.out(), false, -1);
@@ -1590,10 +1604,10 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
 };

 struct SetValueOpTranscriber : public OpTranscriber {
-  ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
-                                   ir::Program* program,
-                                   const OpDesc& op_desc,
-                                   const OpInputInfo& input_info) override {
+  pir::OpResult GetAttributeAsInput(pir::IrContext* ctx,
+                                    pir::Program* program,
+                                    const OpDesc& op_desc,
+                                    const OpInputInfo& input_info) override {
     auto& attribute_translator = AttributeTranslator::instance();
     auto& op_normalizer = OpNameNormalizer::instance();
@@ -1608,23 +1622,24 @@ struct SetValueOpTranscriber : public OpTranscriber {
     framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
     VLOG(10) << "[" << op_desc.Type() << "][attribute]"
              << " name: " << legacy_attr_name << " " << legacy_attr.index();
-    ir::Attribute new_attr =
+    pir::Attribute new_attr =
         attribute_translator("paddle::dialect::IntArrayAttribute", legacy_attr);
-    ir::Operation* defining_op =
+    pir::Operation* defining_op =
         InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
     return defining_op->result(0);
   }
 };

 struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber {
-  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
+  pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
+                            const OpDesc& op_desc) override {
     std::string target_op_name = dialect::SetValueWithTensorOp::name();
     const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
     if (!op_info) {
       IR_THROW(
           "Op set_value should have corresponding OpInfo "
-          "pd.set_value_with_tensor");
+          "pd_op.set_value_with_tensor");
     }

     return op_info;
@@ -1635,12 +1650,12 @@ struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber {
     if (input_name != "values") {
       return nullptr;
     }
-    return [](ir::IrContext* ctx,
+    return [](pir::IrContext* ctx,
               TranslationContext* param_map,
               const OpDesc& op_desc,
               const std::string&,
               const OpInputInfo& info,
-              ir::Program* program) -> ir::OpResult {
+              pir::Program* program) -> pir::OpResult {
       std::vector<std::string> legacy_input_vars;
       IR_ENFORCE(op_desc.HasInput("ValueTensor"),
                  "[set_value] should have ValueTensor");
@@ -1662,13 +1677,14 @@ struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber {
 };

 struct SetValueGradOpTranscriber : public SetValueWithTensorOpTranscriber {
-  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
+  pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
+                            const OpDesc& op_desc) override {
     std::string target_op_name = dialect::SetValueGradOp::name();
     const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
     if (!op_info) {
       IR_THROW(
           "Op set_value_grad should have corresponding OpInfo "
-          "pd.set_value_grad");
+          "pd_op.set_value_grad");
     }

     return op_info;
@@ -1676,10 +1692,10 @@ struct SetValueGradOpTranscriber : public SetValueWithTensorOpTranscriber {
 };

 struct LegacySetValueDispatcher : public OpTranscriber {
-  ir::Operation* operator()(ir::IrContext* ctx,
-                            TranslationContext* param_map,
-                            const OpDesc& op_desc,
-                            ir::Program* program) override {
+  pir::Operation* operator()(pir::IrContext* ctx,
+                             TranslationContext* param_map,
+                             const OpDesc& op_desc,
+                             pir::Program* program) override {
     std::vector<std::string> legacy_input_vars;

     // if op has input with name "ValueTensor", then use that input as value
@@ -1698,8 +1714,8 @@ struct LegacySetValueDispatcher : public OpTranscriber {
 };

 OpTranslator::OpTranslator() {
-  ir::IrContext* ctx = ir::IrContext::Instance();
-  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
+  pir::IrContext* ctx = pir::IrContext::Instance();
+  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();

   general_handler = OpTranscriber();
   special_handlers["add_n"] = AddNOpTranscriber();
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.h b/paddle/fluid/ir_adaptor/translator/op_translator.h
index afc7566be12b3..2ae6643999b8d 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.h
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.h
@@ -20,12 +20,12 @@
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h"
 #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
-#include "paddle/ir/core/ir_context.h"
-#include "paddle/ir/core/operation.h"
-#include "paddle/ir/core/program.h"
-#include "paddle/ir/core/value.h"
+#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
+#include "paddle/pir/core/ir_context.h"
"paddle/pir/core/ir_context.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/value.h" namespace paddle { namespace translator { @@ -41,7 +41,7 @@ struct OpTranscriber { using IdxInVector = size_t; using ResultIdx = std::tuple; using OpDesc = paddle::framework::OpDesc; - using OpOutputTypeList = std::vector; + using OpOutputTypeList = std::vector; using OpOutputMapping = std::unordered_map; using OpInputInfo = dialect::OpInputInfo; using OpInputInfoList = std::vector; @@ -49,51 +49,51 @@ struct OpTranscriber { using OpAttributeInfoList = std::vector; using OpOutputInfo = dialect::OpOutputInfo; using OpOutputInfoList = std::vector; - using InputHandlerFn = std::function; - using AttributeHandlerFn = std::function; + using InputHandlerFn = std::function; + using AttributeHandlerFn = std::function; public: - virtual ir::Operation* operator()(ir::IrContext* ctx, - TranslationContext* param_map, - const OpDesc& op_desc, - ir::Program* program); + virtual pir::Operation* operator()(pir::IrContext* ctx, + TranslationContext* param_map, + const OpDesc& op_desc, + pir::Program* program); public: - virtual ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc); - virtual std::vector GenerateOperationInput( - ir::IrContext* ctx, + virtual pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc); + virtual std::vector GenerateOperationInput( + pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, const std::string& normalized_op_name, const OpInputInfoList& input_infos, - ir::Program* program); + pir::Program* program); virtual std::tuple GenerateOperationOutput( - ir::IrContext* ctx, + pir::IrContext* ctx, const OpDesc& op_desc, const OpOutputInfoList& output_infos); - virtual void HandleNonexistentAttribute(ir::IrContext*, - ir::AttributeMap* attribute_map, + virtual void HandleNonexistentAttribute(pir::IrContext*, + pir::AttributeMap* attribute_map, const OpAttributeInfo& info); - virtual ir::AttributeMap TranslateOpAttribute( - ir::IrContext* ctx, + virtual pir::AttributeMap TranslateOpAttribute( + pir::IrContext* ctx, const std::string& normalized_op_name, const OpAttributeInfoList& op_attr_infos, const OpDesc& op_desc); - virtual ir::OpResult GetAttributeAsInput(ir::IrContext* ctx, - ir::Program* program, - const OpDesc& op_desc, - const OpInputInfo& input_info); + virtual pir::OpResult GetAttributeAsInput(pir::IrContext* ctx, + pir::Program* program, + const OpDesc& op_desc, + const OpInputInfo& input_info); - virtual void RecordOpResultMapping(ir::IrContext* ctx, + virtual void RecordOpResultMapping(pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, - ir::Operation* operation, + pir::Operation* operation, const OpOutputMapping& arg_to_idx); public: @@ -105,11 +105,11 @@ struct OpTranscriber { const std::string& input_name) { return nullptr; } - virtual void InsertSliceOperationForInput(ir::IrContext* ctx, + virtual void InsertSliceOperationForInput(pir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, const OpInputInfoList& input_infos, - ir::Program* program); + pir::Program* program); }; class OpTranslator { @@ -118,8 +118,8 @@ class OpTranslator { using OpDesc = paddle::framework::OpDesc; using BlockDesc = paddle::framework::BlockDesc; using VarDesc = paddle::framework::VarDesc; - using OpTranslateFn = std::function; + using OpTranslateFn = std::function; private: OpTranslator(); // Disallow instantiation outside of the class. 
diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc
index 9065554781265..678a79a5540b8 100644
--- a/paddle/fluid/ir_adaptor/translator/program_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc
@@ -17,21 +17,21 @@
 #include

 #include "glog/logging.h"
-
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
+#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
 #include "paddle/fluid/ir_adaptor/translator/op_translator.h"
 #include "paddle/fluid/ir_adaptor/translator/type_translator.h"
 #include "paddle/fluid/ir_adaptor/translator/utils.h"
-#include "paddle/ir/core/attribute.h"
-#include "paddle/ir/core/block.h"
-#include "paddle/ir/core/builtin_attribute.h"
-#include "paddle/ir/core/builtin_op.h"
-#include "paddle/ir/core/builtin_type.h"
-#include "paddle/ir/core/enforce.h"
-#include "paddle/ir/core/operation.h"
-#include "paddle/ir/core/value.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/pir/core/attribute.h"
+#include "paddle/pir/core/block.h"
+#include "paddle/pir/core/builtin_attribute.h"
+#include "paddle/pir/core/builtin_op.h"
+#include "paddle/pir/core/builtin_type.h"
+#include "paddle/pir/core/enforce.h"
+#include "paddle/pir/core/operation.h"
+#include "paddle/pir/core/value.h"

 namespace paddle {
 namespace translator {
@@ -46,9 +46,9 @@ const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
 };

 ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
-                                     ir::Program* program)
+                                     pir::Program* program)
     : legacy_program_(legacy_program), program_(program) {
-  ctx_ = ir::IrContext::Instance();
+  ctx_ = pir::IrContext::Instance();
 }

 void ProgramTranslator::Translate() {
@@ -84,31 +84,31 @@ void ProgramTranslator::Translate() {
   }
 }

-inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,
-                                           const VarDesc* var) {
+inline pir::Operation* InsertGetParamaterOp(pir::IrContext* ctx,
+                                            const VarDesc* var) {
   auto& type_translator = TypeTranslator::instance();

-  std::string get_parameter_op_name(ir::GetParameterOp::name());
-  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
-  std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
-      {"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
+  std::string get_parameter_op_name(pir::GetParameterOp::name());
+  pir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
+  std::unordered_map<std::string, pir::Attribute> op_attribute_map = {
+      {"parameter_name", pir::StrAttribute::get(ctx, var->Name())},
   };

-  ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
-  ir::Operation* operation = ir::Operation::Create(
+  pir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
+  pir::Operation* operation = pir::Operation::Create(
       {}, op_attribute_map, {translated_var_type}, op_info);
   return operation;
 }

-inline ir::Operation* InsertSetParamaterOp(ir::IrContext* ctx,
-                                           ir::OpResult defining_op_result,
-                                           const VarDesc* var) {
-  std::string set_parameter_op_name(ir::SetParameterOp::name());
-  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(set_parameter_op_name);
-  std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
-      {"parameter_name", ir::StrAttribute::get(ctx, var->Name())},
+inline pir::Operation* InsertSetParamaterOp(pir::IrContext* ctx,
+                                            pir::OpResult defining_op_result,
+                                            const VarDesc* var) {
+  std::string set_parameter_op_name(pir::SetParameterOp::name());
+  pir::OpInfo op_info = ctx->GetRegisteredOpInfo(set_parameter_op_name);
+  std::unordered_map<std::string, pir::Attribute> op_attribute_map = {
+      {"parameter_name", pir::StrAttribute::get(ctx, var->Name())},
   };
-  ir::Operation* operation = ir::Operation::Create(
+  pir::Operation* operation = pir::Operation::Create(
       {defining_op_result}, op_attribute_map, {}, op_info);
   return operation;
 }
@@ -149,7 +149,7 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) {
           var_desc,
           phi::errors::PreconditionNotMet(
               "VarDesc of [%s] can not be nullptr", var_name));
-      ir::Operation* op = InsertGetParamaterOp(ctx_, var_desc);
+      pir::Operation* op = InsertGetParamaterOp(ctx_, var_desc);
       program_->block()->push_back(op);
       param_map_[var_name] = VariableDefiningInfo(op->result(0));
       VLOG(10) << "[op translated][get parameter]" << var_name;
@@ -178,7 +178,7 @@ void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) {
         continue;
       }
     }
-    ir::Operation* operation = fn(ctx_, &param_map_, *op, program_);
+    pir::Operation* operation = fn(ctx_, &param_map_, *op, program_);
     VLOG(10) << "[op translated][special]" << operation;
   }
 }
@@ -203,7 +203,7 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
       need_set_parameter_op &= (param_map_.count(var_name) != 0);
       need_set_parameter_op &= (!set_input_var_names.count(var_name));
       if (need_set_parameter_op) {
-        ir::OpResult defining_op_result = param_map_[var_name].value;
+        pir::OpResult defining_op_result = param_map_[var_name].value;
         if (!defining_op_result) {
           continue;
         }
@@ -214,11 +214,11 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
           defining_op_result = param_map_.at(var_name).value;
         }

-        ir::Operation* op = InsertSetParamaterOp(
+        pir::Operation* op = InsertSetParamaterOp(
             ctx_, defining_op_result, parameter_name_mappings_[var_name]);

-        ir::Block* block = program_->block();
-        ir::Block::iterator insert_pos = std::find(
+        pir::Block* block = program_->block();
+        pir::Block::iterator insert_pos = std::find(
             block->begin(), block->end(), defining_op_result.owner());

         IR_ENFORCE(
@@ -249,7 +249,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
     if (var == nullptr) {
       continue;
     }
-    ir::OpResult value = value_info.value;
+    pir::OpResult value = value_info.value;
     if (!value) {
       PADDLE_THROW(phi::errors::PreconditionNotMet(
           "Value of [%s] can not be None", var_name));
@@ -261,19 +261,19 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
             "Defining operator of [%s] can not be nullptr", var_name));
     VLOG(8) << "[op translated][stop gradient]" << var_name
             << " from: " << defining_op->name();
-    std::vector<ir::Attribute> stop_gradients;
+    std::vector<pir::Attribute> stop_gradients;
     if (defining_op->HasAttribute(kAttrStopGradients)) {
       stop_gradients = defining_op->attribute(kAttrStopGradients)
-                           .dyn_cast<ir::ArrayAttribute>()
+                           .dyn_cast<pir::ArrayAttribute>()
                            .AsVector();
     } else {
-      stop_gradients = std::vector<ir::Attribute>(
-          defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
+      stop_gradients = std::vector<pir::Attribute>(
+          defining_op->num_results(), pir::BoolAttribute::get(ctx_, false));
     }
     stop_gradients[value.GetResultIndex()] =
-        ir::BoolAttribute::get(ctx_, var->StopGradient());
+        pir::BoolAttribute::get(ctx_, var->StopGradient());
     defining_op->set_attribute(kAttrStopGradients,
-                               ir::ArrayAttribute::get(ctx_, stop_gradients));
+                               pir::ArrayAttribute::get(ctx_, stop_gradients));
   }
 }

@@ -288,7 +288,7 @@ void ProgramTranslator::SetIsPersisableAttributeForAllValue(
     if (var == nullptr) {
       continue;
     }
-    ir::OpResult value = value_info.value;
+    pir::OpResult value = value_info.value;
     if (!value) {
       PADDLE_THROW(phi::errors::PreconditionNotMet(
           "Value of [%s] can not be None", var_name));
@@ -300,19 +300,19 @@ void ProgramTranslator::SetIsPersisableAttributeForAllValue(
             "Defining operator of [%s] can not be nullptr", var_name));
     VLOG(8) << "[op translated][is persisable]" << var_name
             << " from: " << defining_op->name();
-    std::vector<ir::Attribute> is_persisable;
+    std::vector<pir::Attribute> is_persisable;
     if (defining_op->HasAttribute(kAttrIsPersisable)) {
       is_persisable = defining_op->attribute(kAttrIsPersisable)
-                          .dyn_cast<ir::ArrayAttribute>()
+                          .dyn_cast<pir::ArrayAttribute>()
                           .AsVector();
     } else {
-      is_persisable = std::vector<ir::Attribute>(
-          defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
+      is_persisable = std::vector<pir::Attribute>(
+          defining_op->num_results(), pir::BoolAttribute::get(ctx_, false));
     }
     is_persisable[value.GetResultIndex()] =
-        ir::BoolAttribute::get(ctx_, var->Persistable());
+        pir::BoolAttribute::get(ctx_, var->Persistable());
     defining_op->set_attribute(kAttrIsPersisable,
-                               ir::ArrayAttribute::get(ctx_, is_persisable));
+                               pir::ArrayAttribute::get(ctx_, is_persisable));
   }
 }
diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h
index 88901376ae3cb..02ee94d7dd0cd 100644
--- a/paddle/fluid/ir_adaptor/translator/program_translator.h
+++ b/paddle/fluid/ir_adaptor/translator/program_translator.h
@@ -18,17 +18,18 @@
 #include
 #include
 #include
-
+#include "paddle/fluid/framework/op_call_stack.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/ir/core/ir_context.h"
-#include "paddle/ir/core/program.h"
-#include "paddle/ir/core/value.h"
+#include "paddle/pir/core/ir_context.h"
+#include "paddle/pir/core/program.h"
+#include "paddle/pir/core/value.h"

 namespace paddle {
 namespace translator {

 struct VariableDefiningInfo {
-  VariableDefiningInfo(ir::OpResult value,
+  VariableDefiningInfo(pir::OpResult value,
                        bool generated_by_vector = false,
                        int idx_in_vector = -1)
       : value(value),
@@ -36,7 +37,7 @@ struct VariableDefiningInfo {
         idx_in_vector(idx_in_vector) {}
   VariableDefiningInfo() {}

-  ir::OpResult value;
+  pir::OpResult value;

   bool generated_by_vector = false;
   // true if target variable is generated by Vector
@@ -54,14 +55,14 @@ class ProgramTranslator {

 public:
   explicit ProgramTranslator(const ProgramDesc* legacy_program,
-                             ir::Program* program);
+                             pir::Program* program);

   void Translate();

 private:
   const ProgramDesc* legacy_program_;  // not owned
-  ir::Program* program_;               // not owned
-  ir::IrContext* ctx_;                 // not owned
+  pir::Program* program_;              // not owned
+  pir::IrContext* ctx_;                // not owned

   TranslationContext param_map_;
   std::unordered_map<std::string, VarDesc*> parameter_name_mappings_;
diff --git a/paddle/fluid/ir_adaptor/translator/translate.cc b/paddle/fluid/ir_adaptor/translator/translate.cc
index 87bef41641a5f..0f98e557743fc 100644
--- a/paddle/fluid/ir_adaptor/translator/translate.cc
+++ b/paddle/fluid/ir_adaptor/translator/translate.cc
@@ -17,20 +17,20 @@
 #include

 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
 #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
-#include "paddle/ir/core/builtin_dialect.h"
-#include "paddle/ir/core/program.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/pir/core/builtin_dialect.h"
+#include "paddle/pir/core/program.h"

 namespace paddle {

 using LegacyProgramDesc = ::paddle::framework::ProgramDesc;
-using Program = ::ir::Program;
+using Program = pir::Program;

 std::unique_ptr<Program> TranslateLegacyProgramToProgram(
     const LegacyProgramDesc& legacy_program) {
-  ir::IrContext* ctx = ir::IrContext::Instance();
-  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
+  pir::IrContext* ctx = pir::IrContext::Instance();
+  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
   auto program = std::make_unique<Program>(ctx);
   translator::ProgramTranslator program_translator(&legacy_program,
                                                    program.get());
diff --git a/paddle/fluid/ir_adaptor/translator/translate.h b/paddle/fluid/ir_adaptor/translator/translate.h
index 8f604a47761fc..47ad12003f807 100644
--- a/paddle/fluid/ir_adaptor/translator/translate.h
+++ b/paddle/fluid/ir_adaptor/translator/translate.h
@@ -17,12 +17,12 @@
 #include

 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
-#include "paddle/ir/core/program.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/pir/core/program.h"

 namespace paddle {

-std::unique_ptr<::ir::Program> TranslateLegacyProgramToProgram(
+std::unique_ptr<::pir::Program> TranslateLegacyProgramToProgram(
     const ::paddle::framework::ProgramDesc& legacy_program);

 }  // namespace paddle
diff --git a/paddle/fluid/ir_adaptor/translator/type_translator.cc b/paddle/fluid/ir_adaptor/translator/type_translator.cc
index 5c3cbdbc240ce..ef1dbf543c671 100644
--- a/paddle/fluid/ir_adaptor/translator/type_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/type_translator.cc
@@ -15,9 +15,9 @@
 #include "paddle/fluid/ir_adaptor/translator/type_translator.h"

 #include "paddle/fluid/framework/framework.pb.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h"
-#include "paddle/ir/core/builtin_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h"
+#include "paddle/pir/core/builtin_type.h"

 namespace paddle {
 namespace translator {
@@ -34,59 +34,59 @@ using SelectedRowsTypeStorage = paddle::dialect::SelectedRowsTypeStorage;
 TypeTranslator::TypeTranslator() {
   handlers = {
       {VarType::BOOL,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::BoolType::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::BoolType::get(ctx);
        }},
       {VarType::UINT8,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::UInt8Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::UInt8Type::get(ctx);
        }},
       {VarType::INT8,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::Int8Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::Int8Type::get(ctx);
        }},
       {VarType::INT16,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::Int16Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::Int16Type::get(ctx);
        }},
       {VarType::INT32,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::Int32Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::Int32Type::get(ctx);
        }},
       {VarType::INT64,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::Int64Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::Int64Type::get(ctx);
        }},
       {VarType::FP16,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::Float16Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::Float16Type::get(ctx);
        }},
       {VarType::FP32,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::Float32Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::Float32Type::get(ctx);
        }},
       {VarType::FP64,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::Float64Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::Float64Type::get(ctx);
        }},
       {VarType::BF16,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::BFloat16Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::BFloat16Type::get(ctx);
        }},
       {VarType::COMPLEX64,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::Complex64Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::Complex64Type::get(ctx);
        }},
       {VarType::COMPLEX128,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
-         return ir::Complex128Type::get(ctx);
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
+         return pir::Complex128Type::get(ctx);
        }},
       {VarType::LOD_TENSOR,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
          VLOG(10) << "[vartype translating]"
                   << "[" << var_desc.Name() << "] from LOD_TENSOR";
-         ir::Type dtype =
+         pir::Type dtype =
              this->operator[](var_desc.GetDataType())(ctx, var_desc);
          DenseTensorTypeStorage::Dim dim = phi::make_ddim(var_desc.GetShape());
          DenseTensorTypeStorage::DataLayout layout =
@@ -96,18 +96,18 @@ TypeTranslator::TypeTranslator() {
          return DenseTensorType::get(ctx, dtype, dim, layout, lod, offset);
        }},
       {VarType::LOD_TENSOR_ARRAY,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
          VLOG(10) << "[vartype translating]"
                   << "[" << var_desc.Name() << "] from LOD_TENSOR_ARRAY";

-         return ir::VectorType::get(ctx, std::vector<ir::Type>{});
+         return pir::VectorType::get(ctx, std::vector<pir::Type>{});
        }},
       {VarType::SELECTED_ROWS,
-       [&](ir::IrContext* ctx, const VarDesc& var_desc) -> ir::Type {
+       [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type {
          VLOG(10) << "[vartype translating]"
                   << "[" << var_desc.Name() << "] from SELECTED_ROWS";

-         ir::Type dtype =
+         pir::Type dtype =
              this->operator[](var_desc.GetDataType())(ctx, var_desc);
          SelectedRowsTypeStorage::Dim dim = phi::make_ddim(var_desc.GetShape());
@@ -115,7 +115,7 @@ TypeTranslator::TypeTranslator() {
              SelectedRowsTypeStorage::DataLayout::UNDEFINED;
          SelectedRowsTypeStorage::LoD lod = {};
          size_t offset = 0;
-         ir::Type SelectedRows =
+         pir::Type SelectedRows =
              SelectedRowsType::get(ctx, dtype, dim, layout, lod, offset);
          return SelectedRows;
        }},
diff --git a/paddle/fluid/ir_adaptor/translator/type_translator.h b/paddle/fluid/ir_adaptor/translator/type_translator.h
index d93be9a9db371..255795c92d807 100644
--- a/paddle/fluid/ir_adaptor/translator/type_translator.h
+++ b/paddle/fluid/ir_adaptor/translator/type_translator.h
@@ -20,15 +20,15 @@
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
-#include "paddle/ir/core/builtin_type.h"
-#include "paddle/ir/core/dialect.h"
-#include "paddle/ir/core/ir_context.h"
+#include "paddle/pir/core/builtin_type.h"
+#include "paddle/pir/core/dialect.h"
+#include "paddle/pir/core/ir_context.h"

 namespace paddle {
 namespace translator {

 using TypeTranslateFn =
-    std::function<ir::Type(ir::IrContext*, const VarDesc&)>;
+    std::function<pir::Type(pir::IrContext*, const VarDesc&)>;

 class TypeTranslator {
 public:
diff --git a/paddle/fluid/ir_adaptor/translator/utils.cc b/paddle/fluid/ir_adaptor/translator/utils.cc
index 38f3f5fd8c90b..4a591eeedf083 100644
--- a/paddle/fluid/ir_adaptor/translator/utils.cc
+++ b/paddle/fluid/ir_adaptor/translator/utils.cc
@@ -16,43 +16,43 @@

 #include

-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
 #include "paddle/fluid/ir_adaptor/translator/op_translator.h"
-#include "paddle/ir/core/builtin_attribute.h"
-#include "paddle/ir/core/builtin_type.h"
-#include "paddle/ir/core/enforce.h"
-#include "paddle/ir/core/utils.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/pir/core/builtin_attribute.h"
+#include "paddle/pir/core/builtin_type.h"
+#include "paddle/pir/core/enforce.h"
+#include "paddle/pir/core/utils.h"

 namespace paddle {
 namespace translator {

-ir::Operation* InsertSliceOperationForTarget(
-    ir::IrContext* ctx,
+pir::Operation* InsertSliceOperationForTarget(
+    pir::IrContext* ctx,
     TranslationContext* param_map,
-    ir::Program* program,
+    pir::Program* program,
     const VariableDefiningInfo& defining_info,
     const std::string& arg_name) {
-  std::string slice_op_name(ir::SliceOp::name());
-  ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name);
-  std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
-      {"index", ir::Int32Attribute::get(ctx, defining_info.idx_in_vector)},
+  std::string slice_op_name(pir::SliceOp::name());
+  pir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name);
+  std::unordered_map<std::string, pir::Attribute> op_attribute_map = {
+      {"index", pir::Int32Attribute::get(ctx, defining_info.idx_in_vector)},
   };
-  ir::VectorType src_vec_type =
-      defining_info.value.type().dyn_cast<ir::VectorType>();
-  ir::Operation* operation =
-      ir::Operation::Create({defining_info.value},
-                            op_attribute_map,
-                            {src_vec_type[defining_info.idx_in_vector]},
-                            op_info);
+  pir::VectorType src_vec_type =
+      defining_info.value.type().dyn_cast<pir::VectorType>();
+  pir::Operation* operation =
+      pir::Operation::Create({defining_info.value},
+                             op_attribute_map,
+                             {src_vec_type[defining_info.idx_in_vector]},
+                             op_info);
   program->block()->push_back(operation);
-  ir::OpResult target_op_result = operation->result(0);
+  pir::OpResult target_op_result = operation->result(0);
   (*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
   return operation;
 }

 std::ostream& operator<<(std::ostream& os,
                          const std::vector<std::string>& vec_str) {
-  ir::PrintInterleave(
+  pir::PrintInterleave(
       vec_str.begin(),
       vec_str.end(),
       [&os](std::string s) { os << s; },
@@ -61,7 +61,7 @@ std::ostream& operator<<(std::ostream& os,
 }

 std::vector<std::string> CheckUnregisteredOperationInBlock(
-    ir::IrContext* ctx, const framework::BlockDesc& block) {
+    pir::IrContext* ctx, const framework::BlockDesc& block) {
   auto& op_translator = OpTranslator::instance();
   std::vector<std::string> unregistered_ops;
   for (auto op : block.AllOps()) {
@@ -71,7 +71,7 @@ std::vector<std::string> CheckUnregisteredOperationInBlock(
     OpTranscriber general_handler;
     try {
       general_handler.LoopkUpOpInfo(ctx, *op);
-    } catch (ir::IrNotMetException& e) {
+    } catch (pir::IrNotMetException& e) {
       unregistered_ops.push_back(op->Type());
     }
   }
@@ -79,8 +79,8 @@ std::vector<std::string> CheckUnregisteredOperationInBlock(
 }

 std::vector<std::string> CheckUnregisteredOperation(
-    ir::IrContext* ctx, const framework::ProgramDesc& legacy_program) {
-  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
+    pir::IrContext* ctx, const framework::ProgramDesc& legacy_program) {
+  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
   std::vector<std::string> unregistered_ops;
   for (size_t block_idx = 0; block_idx < legacy_program.Size(); block_idx++) {
diff --git a/paddle/fluid/ir_adaptor/translator/utils.h b/paddle/fluid/ir_adaptor/translator/utils.h
index 20e462b5bbde1..63bbde06d2ec0 100644
--- a/paddle/fluid/ir_adaptor/translator/utils.h
+++ b/paddle/fluid/ir_adaptor/translator/utils.h
@@ -19,17 +19,17 @@

 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
-#include "paddle/ir/core/ir_context.h"
-#include "paddle/ir/core/operation.h"
-#include "paddle/ir/core/program.h"
+#include "paddle/pir/core/ir_context.h"
+#include "paddle/pir/core/operation.h"
+#include "paddle/pir/core/program.h"

 namespace paddle {
 namespace translator {

-ir::Operation* InsertSliceOperationForTarget(
-    ir::IrContext* ctx,
+pir::Operation* InsertSliceOperationForTarget(
+    pir::IrContext* ctx,
     TranslationContext* param_map,
-    ir::Program* program,
+    pir::Program* program,
     const VariableDefiningInfo& defining_info,
     const std::string& arg_name);

@@ -37,7 +37,7 @@ std::ostream& operator<<(std::ostream& os,
                          const std::vector<std::string>& vec_str);

 std::vector<std::string> CheckUnregisteredOperation(
-    ir::IrContext* ctx, const framework::ProgramDesc& legacy_program);
+    pir::IrContext* ctx, const framework::ProgramDesc& legacy_program);

 }  // namespace translator
 }  // namespace paddle
diff --git a/paddle/fluid/jit/engine/interpreter_engine.cc b/paddle/fluid/jit/engine/interpreter_engine.cc
index 23cb3ee8b5a20..9c5f7b20d9fd6 100644
--- a/paddle/fluid/jit/engine/interpreter_engine.cc
+++ b/paddle/fluid/jit/engine/interpreter_engine.cc
@@ -20,9 +20,9 @@
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/new_executor/interpretercore.h"
 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/ir/core/program.h"
-#include "paddle/ir/core/value.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/pir/core/program.h"
+#include "paddle/pir/core/value.h"

 namespace paddle {
 namespace jit {
diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc
index fc23dbf88064c..0700028807fc0 100644
--- a/paddle/fluid/operators/cinn/cinn_launch_context.cc
+++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc
@@ -42,9 +42,9 @@
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/string/printf.h"
-#include "paddle/ir/core/program.h"
-#include "paddle/ir/core/value.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/pir/core/program.h"
+#include "paddle/pir/core/value.h"
 #include "paddle/utils/string/string_helper.h"

 namespace paddle {
diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.h b/paddle/fluid/operators/cinn/cinn_launch_op.h
index 2913da9bc5c39..02e70c549cfc2 100644
--- a/paddle/fluid/operators/cinn/cinn_launch_op.h
+++ b/paddle/fluid/operators/cinn/cinn_launch_op.h
@@ -29,9 +29,9 @@
 #include "paddle/fluid/operators/cinn/cinn_launch_context.h"
 #include "paddle/fluid/operators/cinn/cinn_op_helper.h"
 #include "paddle/fluid/platform/profiler.h"
-#include "paddle/ir/core/program.h"
-#include "paddle/ir/core/value.h"
 #include "paddle/phi/core/flags.h"
+#include "paddle/pir/core/program.h"
+#include "paddle/pir/core/value.h"

 PHI_DECLARE_bool(enable_pe_launch_cinn);
 PHI_DECLARE_bool(enable_interpretercore_launch_cinn);
diff --git a/paddle/fluid/operators/class_center_sample_op.cu b/paddle/fluid/operators/class_center_sample_op.cu
index f63baadbde526..efac6332c6d29 100644
--- a/paddle/fluid/operators/class_center_sample_op.cu
+++ b/paddle/fluid/operators/class_center_sample_op.cu
@@ -30,6 +30,7 @@ namespace cub = hipcub;
 #include

 #include "paddle/phi/common/memory_utils.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/tensor_utils.h"

@@ -37,6 +38,9 @@ namespace cub = hipcub;
 #include "paddle/fluid/distributed/collective/process_group.h"
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#include "paddle/phi/core/flags.h"
+PHI_DECLARE_bool(dynamic_static_unified_comm);
 #endif
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -364,21 +368,47 @@ void ClassCenterSampleKernel(const Context& dev_ctx,
         auto task = pg->AllReduce(in_tensor, out_tensor, opts);
         task->Wait();
       } else {
-        const auto& comm = paddle::platform::NCCLCommContext::Instance().Get(
-            ring_id, dev_ctx.GetPlace());
+        paddle::platform::NCCLComm* comm = nullptr;
+        phi::distributed::NCCLCommContext* comm_ctx = nullptr;
         // use global calculate stream
-        const auto calcu_stream =
+        auto stream =
             static_cast<GPUContext*>(
                 phi::DeviceContextPool::Instance().Get(dev_ctx.GetPlace()))
                 ->stream();
-        PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
-            num_classes_per_device_ptr,
-            num_classes_per_device_ptr,
-            num_classes_per_device.numel(),
-            phi::ToNCCLDataType(num_classes_per_device.dtype()),
-            ncclSum,
-            comm->comm(),
-            calcu_stream));
+        const auto& comm_context_manager =
+            phi::distributed::CommContextManager::GetInstance();
+        if (FLAGS_dynamic_static_unified_comm) {
+          PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
+                            true,
+                            errors::InvalidArgument(
+                                "You choose to use new communication library by "
+                                "setting environment "
+                                "variable FLAGS_dynamic_static_unified_comm "
+                                "True. But ring_id(%d) is "
+                                "not found in comm_context_manager.",
+                                std::to_string(ring_id)));
+          comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
+              comm_context_manager.Get(std::to_string(ring_id)));
+          stream = comm_ctx->GetStream();
+        } else {
+          comm = paddle::platform::NCCLCommContext::Instance().Get(
+              ring_id, dev_ctx.GetPlace());
+        }
+
+        if (comm_ctx) {
+          comm_ctx->AllReduce(
+              &num_classes_per_device, num_classes_per_device, ncclSum, stream);
+          paddle::platform::GpuStreamSync(stream);
+        } else {
+          PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(
+              num_classes_per_device_ptr,
+              num_classes_per_device_ptr,
+              num_classes_per_device.numel(),
+              phi::ToNCCLDataType(num_classes_per_device.dtype()),
+              ncclSum,
+              comm->comm(),
+              stream));
+        }
       }
     }
 #endif
diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
index 3c70b997a7fd8..344dcd36e5235 100644
--- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
+++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
 #include "paddle/phi/kernels/reduce_sum_kernel.h"

 #include "paddle/fluid/framework/eigen.h"
@@ -26,6 +27,12 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/softmax_impl.h"

+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#include "paddle/phi/core/flags.h"
+PHI_DECLARE_bool(dynamic_static_unified_comm);
+#endif
+
 namespace paddle {
 namespace operators {

@@ -136,13 +143,41 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
     const int rank = ctx.Attr<int>("rank");

     const auto& place = ctx.GetPlace();
-    const auto& comm = platform::NCCLCommContext::Instance().Get(rid, place);
     auto& dev_ctx = ctx.template device_context<phi::GPUContext>();

-    // use global calculate stream
-    const auto stream = static_cast<phi::GPUContext*>(
-                            platform::DeviceContextPool::Instance().Get(place))
-                            ->stream();
+    gpuStream_t stream = nullptr;
+    platform::NCCLComm* comm = nullptr;
+    phi::distributed::NCCLCommContext* comm_ctx = nullptr;
+
+    const auto& comm_context_manager =
+        phi::distributed::CommContextManager::GetInstance();
+
+    if (FLAGS_dynamic_static_unified_comm) {
+      PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
+                        true,
+                        platform::errors::InvalidArgument(
+                            "You choose to use new communication library by "
+                            "setting environment "
+                            "variable FLAGS_dynamic_static_unified_comm True. "
+                            "But ring_id(%d) is "
+                            "not found in comm_context_manager.",
+                            std::to_string(rid)));
+      comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
+          comm_context_manager.Get(std::to_string(rid)));
+      PADDLE_ENFORCE_NE(comm_ctx,
+                        nullptr,
+                        platform::errors::Unavailable(
+                            "NCCLCommContext is nullptr, collective op should "
+                            "have ring_id attr."));
+
+      stream = comm_ctx->GetStream();
+      VLOG(3) << "new comm_context_manager has ring_id " << rid;
+    } else {  // old comm_context
+      comm = platform::NCCLCommContext::Instance().Get(rid, place);
+
+      stream = comm->stream();
+      VLOG(3) << "old NCCLCommContext has ring_id " << rid;
+    }

     // allocate memory on device.
     softmax->mutable_data<T>(place);
@@ -166,21 +201,27 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {

     // step 1, obtain logit_max
     phi::DenseTensor logits_max;
     logits_max = ctx.AllocateTmpTensor<T, phi::GPUContext>({N, 1}, dev_ctx);
-    void* logits_max_buff = logits_max.mutable_data<T>(place);

     auto eigen_logits_max = phi::funcs::EigenMatrix<T>::From(logits_max);
     Eigen::DSizes<int, 1> along_axis(1);
     eigen_logits_max.device(*dev_ctx.eigen_device()) =
         eigen_logits.maximum(along_axis);
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
-        logits_max_buff,
-        logits_max_buff,
-        logits_max.numel(),
-        platform::ToNCCLDataType(
-            framework::TransToProtoVarType(logits_max.dtype())),
-        ncclMax,
-        comm->comm(),
-        stream));
+
+    if (comm_ctx) {
+      comm_ctx->AllReduce(&logits_max, logits_max, ncclMax, stream);
+    } else {
+      void* logits_max_buff = logits_max.mutable_data<T>(place);
+
+      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
+          logits_max_buff,
+          logits_max_buff,
+          logits_max.numel(),
+          platform::ToNCCLDataType(
+              framework::TransToProtoVarType(logits_max.dtype())),
+          ncclMax,
+          comm->comm(),
+          stream));
+    }

     // step 2, obtain logit - logit_max
     Eigen::DSizes<int, 2> batch_by_one(N, 1);
@@ -230,39 +271,47 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
           nranks);
     }

-    void* predict_logits_buff = predicted_logits.mutable_data<T>(place);
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
-        predict_logits_buff,
-        predict_logits_buff,
-        predicted_logits.numel(),
-        platform::ToNCCLDataType(
-            framework::TransToProtoVarType(predicted_logits.dtype())),
-        ncclSum,
-        comm->comm(),
-        stream));
-
-    // step 4, obtain exp(logit)
+    predicted_logits.mutable_data<T>(place);
+    if (comm_ctx) {
+      comm_ctx->AllReduce(&predicted_logits, predicted_logits, ncclSum, stream);
+    } else {
+      void* predict_logits_buff = predicted_logits.data<T>();
+      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
+          predict_logits_buff,
+          predict_logits_buff,
+          predicted_logits.numel(),
+          platform::ToNCCLDataType(
+              framework::TransToProtoVarType(predicted_logits.dtype())),
+          ncclSum,
+          comm->comm(),
+          stream));
+    }
     eigen_softmax.device(*dev_ctx.eigen_device()) = eigen_softmax.exp();

     // step 5, obtain sum_exp_logits
     phi::DenseTensor sum_exp_logits;
     sum_exp_logits = ctx.AllocateTmpTensor<T, phi::GPUContext>({N, 1}, dev_ctx);
-    void* sum_exp_logits_buff = sum_exp_logits.mutable_data<T>(place);
+    sum_exp_logits.mutable_data<T>(place);

     auto eigen_sum_exp_logits =
         phi::funcs::EigenMatrix<T>::From(sum_exp_logits);
     eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) =
         eigen_softmax.sum(along_axis);
-    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
-        sum_exp_logits_buff,
-        sum_exp_logits_buff,
-        sum_exp_logits.numel(),
-        platform::ToNCCLDataType(
-            framework::TransToProtoVarType(sum_exp_logits.dtype())),
-        ncclSum,
-        comm->comm(),
-        stream));
+    if (comm_ctx) {
+      comm_ctx->AllReduce(&sum_exp_logits, sum_exp_logits, ncclSum, stream);
+    } else {
+      void* sum_exp_logits_buff = sum_exp_logits.data<T>();
+      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
+          sum_exp_logits_buff,
+          sum_exp_logits_buff,
+          sum_exp_logits.numel(),
+          platform::ToNCCLDataType(
+              framework::TransToProtoVarType(sum_exp_logits.dtype())),
+          ncclSum,
+          comm->comm(),
+          stream));
+    }

     if (label_type == framework::proto::VarType::INT32) {
       CaculateLoss
diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
index 3136ac21ab764..45d91dc724108 100644
--- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -13,12 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/collective/global_scatter_op.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
+
+#include "paddle/fluid/distributed/collective/utils.h"
+#include "paddle/fluid/framework/convert_utils.h"

 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#include "paddle/phi/core/flags.h"
+PHI_DECLARE_bool(dynamic_static_unified_comm);
 #endif
-#include "paddle/fluid/framework/convert_utils.h"

 namespace paddle {
 namespace operators {
@@ -78,15 +84,48 @@ struct GlobalScatterFunctor<phi::GPUContext, T> {
                           ring_id));
     auto place = ctx.GetPlace();
-    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
+    platform::NCCLComm* comm = nullptr;
+    phi::distributed::NCCLCommContext* comm_ctx = nullptr;
+    int nranks = 0;
+
+    const auto& comm_context_manager =
+        phi::distributed::CommContextManager::GetInstance();
+
+    if (FLAGS_dynamic_static_unified_comm) {
+      PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
+                        true,
+                        platform::errors::InvalidArgument(
+                            "You choose to use new communication library by "
+                            "setting environment "
+                            "variable FLAGS_dynamic_static_unified_comm True. "
+                            "But ring_id(%d) is "
+                            "not found in comm_context_manager.",
+                            std::to_string(ring_id)));
+      comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
+          comm_context_manager.Get(std::to_string(ring_id)));
+      PADDLE_ENFORCE_NE(comm_ctx,
+                        nullptr,
+                        platform::errors::Unavailable(
+                            "NCCLCommContext is nullptr, collective op should "
+                            "have ring_id attr."));
+
+      stream = comm_ctx->GetStream();
+      nranks = comm_ctx->GetSize();
+      VLOG(3) << "new comm_context_manager has ring_id " << ring_id;
+    } else {  // old comm_context
+      comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
+
+      stream = comm->stream();
+      nranks = comm->nranks();
+      VLOG(3) << "old NCCLCommContext has ring_id " << ring_id;
+    }
+
     if (ctx.Attr<bool>("use_calc_stream")) {
       // should use the calc stream from the ExecutionContext
stream = ctx.cuda_device_context().stream(); - } else { - stream = comm->stream(); } - int nranks = comm->nranks(); + auto in_feat = x->dims()[1]; auto n_expert = local_count->dims()[0] / nranks; int64_t fwd_count = 0; @@ -103,34 +142,62 @@ struct GlobalScatterFunctor { } auto recv_ptr = 0; - auto send_buf = x->data(); - auto recv_buf = out->mutable_data(out_dims, place); + out->mutable_data(out_dims, place); - for (auto i = 0; i < n_expert; ++i) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); - for (auto j = 0; j < nranks; ++j) { - int idx = i + j * n_expert; - if (cpu_local_count_data[idx]) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclSend(send_buf + expert_ptr[idx] * in_feat, - cpu_local_count_data[idx] * in_feat, - dtype, - j, - comm->comm(), - stream)); + if (comm_ctx) { + for (auto i = 0; i < n_expert; ++i) { + comm_ctx->GroupStart(); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + auto send_buf = distributed::GetPartialTensor( + *x, + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat); + + comm_ctx->Send( + send_buf, cpu_local_count_data[idx] * in_feat, j, stream); + } + if (cpu_global_count_data[idx]) { + auto recv_buf = distributed::GetPartialTensor( + *out, recv_ptr * in_feat, cpu_global_count_data[idx] * in_feat); + comm_ctx->Recv( + &recv_buf, cpu_global_count_data[idx] * in_feat, j, stream); + recv_ptr += cpu_global_count_data[idx]; + } } - if (cpu_global_count_data[idx]) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclRecv(recv_buf + recv_ptr * in_feat, - cpu_global_count_data[idx] * in_feat, - dtype, - j, - comm->comm(), - stream)); - recv_ptr += cpu_global_count_data[idx]; + comm_ctx->GroupEnd(); + } + } else { + auto send_buf = x->data(); + auto recv_buf = out->data(); + + for (auto i = 0; i < n_expert; ++i) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto j = 0; j < nranks; ++j) { + int idx = i + j * n_expert; + if (cpu_local_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + send_buf + expert_ptr[idx] * in_feat, + cpu_local_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + } + if (cpu_global_count_data[idx]) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + recv_buf + recv_ptr * in_feat, + cpu_global_count_data[idx] * in_feat, + dtype, + j, + comm->comm(), + stream)); + recv_ptr += cpu_global_count_data[idx]; + } } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); } #else diff --git a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc index d22fd70bd0f61..cf353c12ffa49 100644 --- a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc @@ -18,8 +18,14 @@ limitations under the License. 
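Every collective kernel touched by this patch selects between the two communication stacks the same way before doing any transport work. The shape of that selection, distilled into an illustrative helper (the flag, manager, and context types are the real APIs used above; the helper itself and its name are not part of the patch, and error handling is trimmed):

    // Sketch only: centralizes the new-vs-legacy comm selection shown in
    // global_scatter_op.cu.cc above.
    struct CommHandles {
      platform::NCCLComm* comm = nullptr;                     // legacy stack
      phi::distributed::NCCLCommContext* comm_ctx = nullptr;  // unified stack
      gpuStream_t stream = nullptr;
      int nranks = 0;
    };

    CommHandles SelectComm(int ring_id, const phi::Place& place) {
      CommHandles h;
      const auto& mgr = phi::distributed::CommContextManager::GetInstance();
      if (FLAGS_dynamic_static_unified_comm) {
        // New path: the manager owns one context per ring id.
        h.comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
            mgr.Get(std::to_string(ring_id)));
        h.stream = h.comm_ctx->GetStream();
        h.nranks = h.comm_ctx->GetSize();
      } else {
        // Legacy path: the singleton keyed by (ring_id, place).
        h.comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
        h.stream = h.comm->stream();
        h.nranks = h.comm->nranks();
      }
      return h;
    }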
*/ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); #endif +#include "paddle/fluid/distributed/collective/utils.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" + namespace paddle { namespace operators { @@ -38,17 +44,57 @@ class PartialAllGatherOpCUDAKernel : public framework::OpKernel { int rank = ctx.Attr("rank"); int rid = ctx.Attr("ring_id"); auto place = ctx.GetPlace(); - auto comm = platform::NCCLCommContext::Instance().Get(rid, place); + gpuStream_t stream = nullptr; + + platform::NCCLComm* comm = nullptr; + phi::distributed::NCCLCommContext* comm_ctx = nullptr; + + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + + int real_nranks = 0; + int real_rank = 0; + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), + true, + platform::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(rid))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(rid))); + PADDLE_ENFORCE_NE(comm_ctx, + nullptr, + platform::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + + stream = comm_ctx->GetStream(); + real_nranks = comm_ctx->GetSize(); + real_rank = comm_ctx->GetRank(); + VLOG(3) << "new comm_context_manager has ring_id " << rid; + } else { // old comm_context + comm = platform::NCCLCommContext::Instance().Get(rid, place); + + stream = comm->stream(); + real_nranks = comm->nranks(); + real_rank = comm->rank(); + VLOG(3) << "old NCCLCommContext has ring_id " << rid; + } PADDLE_ENFORCE_EQ( nranks, - comm->nranks(), + real_nranks, platform::errors::InvalidArgument( - "nranks: %s should equal to %s", nranks, comm->nranks())); + "nranks: %s should equal to %s", nranks, real_nranks)); PADDLE_ENFORCE_EQ(rank, - comm->rank(), + real_rank, platform::errors::InvalidArgument( - "rank: %s should equal to %s", rank, comm->rank())); + "rank: %s should equal to %s", rank, real_rank)); + PADDLE_ENFORCE_EQ( (numel % nranks), 0, @@ -70,24 +116,26 @@ class PartialAllGatherOpCUDAKernel : public framework::OpKernel { auto task = pg->AllGather(out, *in, offset, send_numel, /*sync_op*/ true); task->Wait(); } else { - const T* send_buff = in->data() + offset; - T* recv_buff = out->data(); - - gpuStream_t stream = nullptr; if (ctx.Attr("use_calc_stream")) { // should ExecutionContext for calc stream. 
stream = ctx.cuda_device_context().stream(); - } else { - stream = comm->stream(); } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclAllGather(send_buff, - recv_buff, - send_numel, - static_cast<ncclDataType_t>(dtype), - comm->comm(), - stream)); + if (comm_ctx) { + auto send_buf = distributed::GetPartialTensor(*in, offset, send_numel); + + comm_ctx->AllGather(out, send_buf, stream); + } else { + const T* send_buff = in->data<T>() + offset; + T* recv_buff = out->data<T>(); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclAllGather(send_buff, + recv_buff, + send_numel, + static_cast<ncclDataType_t>(dtype), + comm->comm(), + stream)); + } } #else PADDLE_THROW(platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/operators/collective/partial_send_op.cu.cc b/paddle/fluid/operators/collective/partial_send_op.cu.cc index 4f9fc41bc4e16..67089a18c8e4f 100644 --- a/paddle/fluid/operators/collective/partial_send_op.cu.cc +++ b/paddle/fluid/operators/collective/partial_send_op.cu.cc @@ -18,8 +18,14 @@ limitations under the License. */ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/flags.h" +PHI_DECLARE_bool(dynamic_static_unified_comm); #endif + +#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/framework/convert_utils.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" namespace paddle { namespace operators { @@ -75,33 +81,82 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> { } else { gpuStream_t stream = nullptr; auto place = ctx.GetPlace(); - auto comm = platform::NCCLCommContext::Instance().Get(rid, place); + + platform::NCCLComm* comm = nullptr; + phi::distributed::NCCLCommContext* comm_ctx = nullptr; + int nranks = 0; + int rank = 0; + + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ( + comm_context_manager.Has(std::to_string(rid)), + true, + platform::errors::InvalidArgument( + "You have chosen to use the new communication " + "library by setting the environment variable " + "FLAGS_dynamic_static_unified_comm to True, " + "but ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(rid))); + comm_ctx = static_cast<phi::distributed::NCCLCommContext*>( + comm_context_manager.Get(std::to_string(rid))); + PADDLE_ENFORCE_NE( + comm_ctx, + nullptr, + platform::errors::Unavailable( + "NCCLCommContext is nullptr; a collective op " + "should have a ring_id attr.")); + + stream = comm_ctx->GetStream(); + nranks = comm_ctx->GetSize(); + rank = comm_ctx->GetRank(); + + VLOG(3) << "new comm_context_manager has ring_id " << rid; + } else { // old comm_context + comm = platform::NCCLCommContext::Instance().Get(rid, place); + + stream = comm->stream(); + nranks = comm->nranks(); + rank = comm->rank(); + + VLOG(3) << "old NCCLCommContext has ring_id " << rid; + } + if (ctx.Attr<bool>("use_calc_stream")) { // should use the calc stream from the ExecutionContext.
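With the unified context, the offset arithmetic moves out of the NCCL call: distributed::GetPartialTensor builds a tensor view over [offset, offset + numel) and the context API consumes it directly. The allgather case reduces to two lines, as a sketch reusing only calls present in this diff (in, out, offset, send_numel, comm_ctx, and stream as prepared in the kernel above):

    // Sketch only: gather send_numel elements starting at `offset` from every
    // rank into `out`, replacing the raw ncclAllGather pointer arithmetic.
    auto send_buf = distributed::GetPartialTensor(*in, offset, send_numel);
    comm_ctx->AllGather(out, send_buf, stream);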
stream = ctx.cuda_device_context().stream(); - } else { - stream = comm->stream(); } + PADDLE_ENFORCE_LT(peer, - comm->nranks(), + nranks, platform::errors::InvalidArgument( "The value of peer (%d) you set must " - "be less than comm->nranks (%d).", + "be less than ranks (%d).", peer, - comm->nranks())); + nranks)); ncclDataType_t dtype = platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclSend(x->data() + offset, - send_numel, - dtype, - peer, - comm->comm(), - stream)); - VLOG(3) << "rank " << comm->rank() << " send " << send_numel - << " from offset[" << offset << "] to " << peer; + if (comm_ctx) { + auto send_buf = distributed::GetPartialTensor(*x, offset, send_numel); + + comm_ctx->Send(send_buf, send_numel, peer, stream); + } else { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::ncclSend(x->data() + offset, + send_numel, + dtype, + peer, + comm->comm(), + stream)); + } + + VLOG(3) << "rank " << rank << " send " << send_numel << " from offset[" + << offset << "] to " << peer; } #else PADDLE_THROW(platform::errors::Unavailable( diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt index 03e6527a58a73..b08f26a75f68c 100644 --- a/paddle/fluid/operators/controlflow/CMakeLists.txt +++ b/paddle/fluid/operators/controlflow/CMakeLists.txt @@ -3,13 +3,17 @@ if(WITH_UNITY_BUILD) # Load Unity Build rules for operators in paddle/fluid/operators/controlflow. include(unity_build_rule.cmake) endif() -register_operators(EXCLUDES conditional_block_op DEPS naive_executor +register_operators(EXCLUDES conditional_block_op pylayer_op DEPS naive_executor standalone_executor) cc_library( conditional_block_op SRCS conditional_block_op.cc DEPS executor standalone_executor) +cc_library( + pylayer_op + SRCS pylayer_op.cc + DEPS standalone_executor) cc_library( op_variant SRCS op_variant.cc @@ -18,6 +22,10 @@ cc_library( conditional_block_op_helper SRCS conditional_block_op_helper.cc DEPS op_variant operator conditional_block_op) +cc_library( + pylayer_op_helper + SRCS pylayer_op_helper.cc + DEPS op_variant operator pylayer_op) cc_library( recurrent_op_helper SRCS recurrent_op_helper.cc @@ -28,7 +36,8 @@ cc_library( DEPS op_variant operator) if(WITH_UNITY_BUILD) - target_link_libraries(paddle_operators_controlflow_unity conditional_block_op) + target_link_libraries(paddle_operators_controlflow_unity conditional_block_op + pylayer_op) else() target_link_libraries(conditional_block_infer_op conditional_block_op) endif() diff --git a/paddle/fluid/operators/controlflow/pylayer_op.cc b/paddle/fluid/operators/controlflow/pylayer_op.cc index eef62289d76f5..fe05f47707445 100644 --- a/paddle/fluid/operators/controlflow/pylayer_op.cc +++ b/paddle/fluid/operators/controlflow/pylayer_op.cc @@ -51,14 +51,35 @@ void PyLayerOp::CreateInterpreter( dev_place, block, cur_scope, execution_config)); VLOG(10) << "[interpreterCore] created:" << core_; } else { - // NOTE: Borrowed from - // `paddle/fluid/operators/controlflow/control_flow_op_helper.h` - // TODO(MarioLulab): Add PyLayer Helper ? 
BuildScopeForControlFlowOp(*core_, block, cur_scope); core_->reset_scope(cur_scope); } } +class PyLayerForwardOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput(PyLayerOp::kInputs, "The input variables of the sub-block.") + .AsDuplicable(); + AddOutput(PyLayerOp::kOutputs, "The output variables of the sub-block.") + .AsDuplicable(); + AddOutput( + PyLayerOp::kScope, + "(std::vector) The scope of static pylayer block, used for " + "passing intermediate variables between forward and backward."); + AddAttr>( + "blocks", + "The blocks of PyLayer operator where blocks[0] indicates the forward " + "block and blocks[1] indicates the backward block."); + AddComment(R"DOC(PyLayer operator + +The PyLayer Operator is designed to support `@to_static` for `PyLayer in Dynamic Graph`. + + +)DOC"); + } +}; + class PyLayerForwardOp : public PyLayerOp { public: PyLayerForwardOp(const std::string &type, @@ -109,7 +130,7 @@ class PyLayerForwardOp : public PyLayerOp { class PyLayerForwardInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { - // TODO(MarioLulab): do nothing. + // NOTE(MarioLulab): do nothing. } }; diff --git a/paddle/fluid/operators/controlflow/pylayer_op.h b/paddle/fluid/operators/controlflow/pylayer_op.h index e06daad78041d..afbb2fd151a40 100644 --- a/paddle/fluid/operators/controlflow/pylayer_op.h +++ b/paddle/fluid/operators/controlflow/pylayer_op.h @@ -49,27 +49,5 @@ class PyLayerOp : public framework::OperatorBase { protected: mutable std::shared_ptr core_{nullptr}; }; - -class PyLayerForwardOpProtoMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput(PyLayerOp::kInputs, "The input variables of the sub-block.") - .AsDuplicable(); - AddOutput(PyLayerOp::kOutputs, "The output variables of the sub-block.") - .AsDuplicable(); - // TODO(MarioLulab): Must Use std::vector here ? - AddOutput(PyLayerOp::kScope, - "(std::vector) The scope of static pylayer block."); - AddAttr>( - "blocks", "The blocks of PyLayer operator"); - AddComment(R"DOC(PyLayer operator - -TO-DO: added by luqi - - -)DOC"); - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/controlflow/pylayer_op_helper.cc b/paddle/fluid/operators/controlflow/pylayer_op_helper.cc new file mode 100644 index 0000000000000..dabe561eea3e7 --- /dev/null +++ b/paddle/fluid/operators/controlflow/pylayer_op_helper.cc @@ -0,0 +1,177 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
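Packing both sub-blocks into one "blocks" attribute means consumers index into it instead of carrying separate forward/backward attributes. A hedged sketch of reading it back from an OpDesc follows (the helper and its name are illustrative and not part of the patch; PADDLE_GET_CONST is assumed to be the stock variant accessor):

    // Sketch only: fetch the PyLayer sub-blocks, where blocks[0] is the
    // forward block and blocks[1] is the backward block.
    const framework::BlockDesc* PyLayerBackwardBlock(
        const framework::OpDesc& op) {
      auto blocks = PADDLE_GET_CONST(std::vector<framework::BlockDesc*>,
                                     op.GetAttr("blocks"));
      PADDLE_ENFORCE_EQ(blocks.size(),
                        2UL,
                        platform::errors::InvalidArgument(
                            "PyLayer expects 2 blocks, got %d.", blocks.size()));
      return blocks[1];
    }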
+ +#include "paddle/fluid/operators/controlflow/pylayer_op_helper.h" + +#include + +namespace paddle { +namespace framework { +class ProgramDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { + +static bool IsMatchedPyLayerOpAndPyLayerGradOp(const OpVariant &fwd_op, + const OpVariant &bwd_op) { + return fwd_op.Outputs().at(PyLayerOp::kScope) == + bwd_op.Inputs().at(PyLayerOp::kScope); +} + +static void FindAllPyLayerOpAndPyLayerGradOp( + const framework::ProgramDesc &program, + std::vector *fwd_ops, + std::vector *bwd_ops) { + PADDLE_ENFORCE_GE( + fwd_ops->size(), + bwd_ops->size(), + platform::errors::InvalidArgument( + "Size of forward ops must be greater or equal to backward ops. The " + "number of forward ops is %d and the number of backward ops is %d", + fwd_ops->size(), + bwd_ops->size())); + + for (size_t i = 1; i < program.Size(); ++i) { + auto &block = program.Block(i); + for (size_t j = 0; j < block.OpSize(); ++j) { + auto *op = block.Op(j); + if (op->Type() == "pylayer") { + fwd_ops->emplace_back(op); + } else if (op->Type() == "pylayer_grad") { + bwd_ops->emplace_back(op); + } + } + } + + PADDLE_ENFORCE_GE( + fwd_ops->size(), + bwd_ops->size(), + platform::errors::InvalidArgument( + "There are more pylayer_grad ops than " + "pylayer ops in the graph or program. The number of " + "forward ops is %d and the number of backward ops is %d", + fwd_ops->size(), + bwd_ops->size())); +} + +static void SetSkipVarsForPyLayerOp(OpVariant *fwd_op, OpVariant *bwd_op) { + auto *grad_block = bwd_op->Attr("backward_block"); + auto is_skippable_in_fwd = [grad_block](const std::string &var_name) { + return var_name != framework::kEmptyVarName && + !grad_block->HasVar(var_name); + }; + + std::unordered_set forward_skip_vars; + for (auto *op_desc : grad_block->AllOps()) { + for (auto &in_arg_name : op_desc->InputArgumentNames()) { + if (is_skippable_in_fwd(in_arg_name)) { + forward_skip_vars.insert(in_arg_name); + } + } + + for (auto &out_arg_name : op_desc->OutputArgumentNames()) { + if (is_skippable_in_fwd(out_arg_name)) { + forward_skip_vars.insert(out_arg_name); + } + } + } + + auto &fwd_attrs = const_cast(fwd_op->Attrs()); + std::vector skip_vars_vec(forward_skip_vars.begin(), + forward_skip_vars.end()); + VLOG(2) << "Prepare to skip " << skip_vars_vec.size() + << " var(s): " << string::join_strings(skip_vars_vec, ' '); + fwd_attrs[PyLayerOp::kSkipEagerDeletionVars] = std::move(skip_vars_vec); +} + +static void PrepareSafeEagerDeletionOnPyLayerOpAndPyLayerGradOp( + const framework::ProgramDesc &program, + std::vector *pylayer_ops, + std::vector *pylayer_grad_ops) { + FindAllPyLayerOpAndPyLayerGradOp(program, pylayer_ops, pylayer_grad_ops); + + VLOG(2) << "Found pylayer op num: " << pylayer_ops->size() + << ", pylayer_grad op num: " << pylayer_grad_ops->size(); + + if (pylayer_grad_ops->empty()) { + return; + } + + std::unordered_set pylayer_op_set( + pylayer_ops->begin(), pylayer_ops->end()); + + for (auto &bwd_op : *pylayer_grad_ops) { + const OpVariant *matched_fwd_op = nullptr; + for (auto &fwd_op : pylayer_op_set) { + if (IsMatchedPyLayerOpAndPyLayerGradOp(fwd_op, bwd_op)) { + PADDLE_ENFORCE_EQ(matched_fwd_op, + nullptr, + platform::errors::PreconditionNotMet( + "Found multiple matched pylayer ops.")); + matched_fwd_op = &fwd_op; + } + } + + PADDLE_ENFORCE_NOT_NULL(matched_fwd_op, + platform::errors::PreconditionNotMet( + "Cannot find matched forward pylayer op.")); + + SetSkipVarsForPyLayerOp(const_cast(matched_fwd_op), &bwd_op); + 
pylayer_op_set.erase(*matched_fwd_op); + } +} + +void PrepareSafeEagerDeletionOnPyLayerOpAndPyLayerGradOp( + const framework::ProgramDesc &program, + int block_id, + const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) { + // If block_id is not 0, return immediately. + // This is because all pylayer_ops and pylayer_grad_ops + // in the whole program are processed when block_id is 0 (i.e., + // when Executor::Run() is called or a ParallelExecutor is constructed). + + // What's more, all pylayer_ops and pylayer_grad_ops + // must be processed when block_id is zero. If not, pylayer_op + // may run first and erase variables used in pylayer_grad_op + // at a moment when the pylayer_grad_ops may not have been constructed yet. + if (block_id != 0) return; + + std::vector<OpVariant> fwd_ops, bwd_ops; + for (auto &op : all_ops) { + if (op->Type() == "pylayer") { + fwd_ops.emplace_back(op.get()); + } else if (op->Type() == "pylayer_grad") { + bwd_ops.emplace_back(op.get()); + } + } + + PrepareSafeEagerDeletionOnPyLayerOpAndPyLayerGradOp( + program, &fwd_ops, &bwd_ops); +} + +void PrepareSafeEagerDeletionOnPyLayerOpAndPyLayerGradOp( + const framework::ProgramDesc &program, + const std::vector<OpVariant> &pylayer_ops, + const std::vector<OpVariant> &pylayer_grad_ops) { + std::vector<OpVariant> fwd_ops = pylayer_ops; + std::vector<OpVariant> bwd_ops = pylayer_grad_ops; + + PrepareSafeEagerDeletionOnPyLayerOpAndPyLayerGradOp( + program, &fwd_ops, &bwd_ops); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/controlflow/pylayer_op_helper.h b/paddle/fluid/operators/controlflow/pylayer_op_helper.h new file mode 100644 index 0000000000000..1295a6cba60a0 --- /dev/null +++ b/paddle/fluid/operators/controlflow/pylayer_op_helper.h @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include <string> +#include <vector> + +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/controlflow/op_variant.h" +#include "paddle/fluid/operators/controlflow/pylayer_op.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace framework { +class ProgramDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { + +void PrepareSafeEagerDeletionOnPyLayerOpAndPyLayerGradOp( + const framework::ProgramDesc &program, + int block_id, + const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops); + +void PrepareSafeEagerDeletionOnPyLayerOpAndPyLayerGradOp( + const framework::ProgramDesc &program, + const std::vector<OpVariant> &pylayer_ops, + const std::vector<OpVariant> &pylayer_grad_ops); + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index 6c949045ef212..778e6ed277fd7 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -242,14 +242,14 @@ For case 2 (assume that the shape of $Y$ is a continuous subsequence of $X$ ): For example: - ..
code-block:: python - - shape(X) = (2, 3, 4, 5), shape(Y) = (,) - shape(X) = (2, 3, 4, 5), shape(Y) = (5,) - shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2 - shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 - shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0 - shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0 + .. code-block:: text + + shape(X) = (2, 3, 4, 5), shape(Y) = (,) + shape(X) = (2, 3, 4, 5), shape(Y) = (5,) + shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2 + shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 + shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0 + shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0 The inputs $X$ and $Y$ can carry the different LoD information. diff --git a/paddle/fluid/operators/fused/scaled_dp_attention.h b/paddle/fluid/operators/fused/scaled_dp_attention.h index c0f3d7fee0f30..016c3995d7383 100644 --- a/paddle/fluid/operators/fused/scaled_dp_attention.h +++ b/paddle/fluid/operators/fused/scaled_dp_attention.h @@ -227,7 +227,6 @@ void softmax_sum_max(float* AB, float refac, int m, int k) { - assert(k % 16 == 0); float max_val = std::numeric_limits::lowest(); __m512 vrefac = _mm512_set1_ps(refac); for (int i = 0; i < m; ++i) { @@ -290,7 +289,6 @@ void update_out_blk(float* output, float* max, int m, int n) { - assert(n % 16 == 0); for (int i = 0; i < m; ++i) { const float* buf = exp_ABC + i * n; float* outbuf = output + i * n; @@ -298,10 +296,12 @@ void update_out_blk(float* output, merr = vexp(merr); __m512 vfac = _mm512_set1_ps(pre_sum[i] / sum[i]); for (int off = 0; off < n; off += 16) { - __m512 vout = _mm512_loadu_ps(outbuf + off); - __m512 vabc = _mm512_loadu_ps(buf + off); + int remain = n - off; + __mmask16 mask = (remain >= 16 ? 
0xffff : (1 << remain) - 1); + __m512 vout = _mm512_maskz_loadu_ps(mask, outbuf + off); + __m512 vabc = _mm512_maskz_loadu_ps(mask, buf + off); __m512 vupt = vout * merr * vfac + vabc; - _mm512_storeu_ps(outbuf + off, vupt); + _mm512_mask_storeu_ps(outbuf + off, mask, vupt); } pre_sum[i] = sum[i]; pre_max[i] = max[i]; @@ -348,8 +348,6 @@ void scaled_dp_attention(const float* query, int iblk = std::min(512, itsize / 1); int oblk = std::min(512, otsize / 1); float refac = scale; - assert(itsize % iblk == 0); - assert(otsize % oblk == 0); #ifdef PADDLE_WITH_MKLML int nth = omp_get_max_threads(); diff --git a/paddle/fluid/operators/fused/self_dp_attention_op.cc b/paddle/fluid/operators/fused/self_dp_attention_op.cc index 04c7424a80dc5..bf0f59865c8ab 100644 --- a/paddle/fluid/operators/fused/self_dp_attention_op.cc +++ b/paddle/fluid/operators/fused/self_dp_attention_op.cc @@ -30,13 +30,6 @@ void SelfDPAttenOp::InferShape(framework::InferShapeContext* ctx) const { "[batchsize, tokensize, 3, nhead, headsize] " ", but now Input X dim is:[%s] ", dim_input)); - PADDLE_ENFORCE_EQ(dim_input[4] % 16, - 0, - platform::errors::InvalidArgument( - "The last dim of input X should be a multiple of 16, " - ", but now the dim is:[%d] " - "Please remove self_attention_fuse_pass from the lists", - dim_input[4])); framework::DDim out_dims( {dim_input[0], dim_input[1], dim_input[3], dim_input[4]}); ctx->SetOutputDim("Out", out_dims); diff --git a/paddle/fluid/operators/generator/CMakeLists.txt b/paddle/fluid/operators/generator/CMakeLists.txt index af346e402bf83..dc88ea0b3a533 100644 --- a/paddle/fluid/operators/generator/CMakeLists.txt +++ b/paddle/fluid/operators/generator/CMakeLists.txt @@ -25,7 +25,7 @@ function(install_py_pyyaml) execute_process( COMMAND ${PYTHON_EXECUTABLE} "-c" - "import re, pyyaml; print(re.compile('/__init__.py.*').sub('',pyyaml.__file__))" + "import re, yaml; print(re.compile('/__init__.py.*').sub('',yaml.__file__))" RESULT_VARIABLE _pyyaml_status ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 5c381e31673ef..69c64de705645 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
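The masked AVX-512 forms are what let the patch drop the assert(n % 16 == 0) guards: a per-iteration mask covers whatever remainder is left, so one loop handles any n. The idiom in isolation, as a self-contained sketch (requires AVX-512F; the function and its name are illustrative, not from the patch):

    // Sketch only: scale n floats by `factor` without requiring 16 to divide n.
    #include <immintrin.h>

    void scale_masked(float* buf, int n, float factor) {
      __m512 vfac = _mm512_set1_ps(factor);
      for (int off = 0; off < n; off += 16) {
        int remain = n - off;
        __mmask16 mask = remain >= 16 ? 0xffff : (1 << remain) - 1;
        // Masked-off lanes load as zero and are left untouched on store.
        __m512 v = _mm512_maskz_loadu_ps(mask, buf + off);
        _mm512_mask_storeu_ps(buf + off, mask, _mm512_mul_ps(v, vfac));
      }
    }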
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/impl/matmul_kernel_impl.h" namespace paddle { namespace operators { @@ -53,61 +54,90 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) { return phi::make_ddim({y_dim[0], 1}); } -template -class MatMulKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto &x = GET_DATA_SAFELY( - context.Input("X"), "Input", "X", "MatMul"); - auto &y = GET_DATA_SAFELY( - context.Input("Y"), "Input", "Y", "MatMul"); - auto *out = context.Output("Out"); - - auto &dev_ctx = context.template device_context(); - dev_ctx.template Alloc(out, out->numel() * sizeof(T)); - - auto blas = phi::funcs::GetBlas(dev_ctx); - auto mat_dim_a = phi::funcs::CreateMatrixDescriptor( - RowMatrixFromVector(x.dims()), 0, context.Attr("transpose_X")); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( - ColumnMatrixFromVector(y.dims()), 0, context.Attr("transpose_Y")); - auto scale = static_cast(context.Attr("alpha")); +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 +template +typename std::enable_if::value, void>::type +ComputeMatmulImpl(const framework::ExecutionContext &context) { + auto &dev_ctx = context.template device_context(); + + auto &x = GET_DATA_SAFELY( + context.Input("X"), "Input", "X", "MatMul"); + auto &y = GET_DATA_SAFELY( + context.Input("Y"), "Input", "Y", "MatMul"); + auto *out = context.Output("Out"); + + dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + + phi::MatmulKernel(dev_ctx, + x, + y, + context.Attr("transpose_X"), + context.Attr("transpose_Y"), + out); +} +#endif - int head_number = 1; +template +typename std::enable_if::value, void>::type +ComputeMatmulImpl(const framework::ExecutionContext &context) { + auto &x = GET_DATA_SAFELY( + context.Input("X"), "Input", "X", "MatMul"); + auto &y = GET_DATA_SAFELY( + context.Input("Y"), "Input", "Y", "MatMul"); + auto *out = context.Output("Out"); + + auto &dev_ctx = context.template device_context(); + dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + + auto blas = phi::funcs::GetBlas(dev_ctx); + auto mat_dim_a = phi::funcs::CreateMatrixDescriptor( + RowMatrixFromVector(x.dims()), 0, context.Attr("transpose_X")); + auto mat_dim_b = phi::funcs::CreateMatrixDescriptor( + ColumnMatrixFromVector(y.dims()), 0, context.Attr("transpose_Y")); + auto scale = static_cast(context.Attr("alpha")); + + int head_number = 1; #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ !defined(PADDLE_WITH_HIP) - head_number = context.Attr("head_number"); + head_number = context.Attr("head_number"); #endif - const auto &x_dims = x.dims(); - const auto &y_dims = y.dims(); - if (head_number <= 1 && x_dims.size() == 3 && y_dims.size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!context.Attr("transpose_X")) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } + const auto &x_dims = x.dims(); + const auto &y_dims = y.dims(); + if (head_number <= 1 && x_dims.size() == 3 && y_dims.size() <= 2) { + // the transpose_X must be false, if is true, the transpose cost much time + if (!context.Attr("transpose_X")) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; } + } #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) && \ !defined(PADDLE_WITH_HIP) - bool 
split_vertical_y = (mat_dim_a.width_ != mat_dim_b.height_); - - if (head_number > 1) { - blas.MatMulWithHead(x, - mat_dim_a, - y, - mat_dim_b, - scale, - head_number, - out, - T(0), - split_vertical_y); - } else { - blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0)); - } -#else + bool split_vertical_y = (mat_dim_a.width_ != mat_dim_b.height_); + + if (head_number > 1) { + blas.MatMulWithHead(x, + mat_dim_a, + y, + mat_dim_b, + scale, + head_number, + out, + T(0), + split_vertical_y); + } else { blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0)); + } +#else + blas.MatMul(x, mat_dim_a, y, mat_dim_b, scale, out, T(0)); #endif +} + +template +class MatMulKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + ComputeMatmulImpl(context); } }; @@ -926,12 +956,31 @@ REGISTER_OP_CPU_KERNEL(matmul_grad_grad, ops::MatMulDoubleGradKernel, ops::MatMulDoubleGradKernel); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_HIP) REGISTER_OP_CUDA_KERNEL( matmul, ops::MatMulKernel, ops::MatMulKernel, ops::MatMulKernel); +#endif + +#if defined(PADDLE_WITH_CUDA) +#if CUDA_VERSION >= 11060 +REGISTER_OP_CUDA_KERNEL( + matmul, + ops::MatMulKernel, + ops::MatMulKernel, + ops::MatMulKernel, + ops::MatMulKernel); +#else +REGISTER_OP_CUDA_KERNEL( + matmul, + ops::MatMulKernel, + ops::MatMulKernel, + ops::MatMulKernel); +#endif +#endif + REGISTER_OP_CUDA_KERNEL( matmul_grad, ops::MatMulGradKernel, @@ -940,7 +989,6 @@ REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(matmul_grad_grad, ops::MatMulDoubleGradKernel, ops::MatMulDoubleGradKernel); -#endif REGISTER_OP_VERSION(matmul).AddCheckpoint( R"ROC(Register matmul for adding the attribute of diff --git a/paddle/fluid/ir/CMakeLists.txt b/paddle/fluid/pir/CMakeLists.txt similarity index 100% rename from paddle/fluid/ir/CMakeLists.txt rename to paddle/fluid/pir/CMakeLists.txt diff --git a/paddle/fluid/pir/dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/CMakeLists.txt new file mode 100644 index 0000000000000..17a73237c5fdb --- /dev/null +++ b/paddle/fluid/pir/dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(operator) +add_subdirectory(kernel) diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/kernel/CMakeLists.txt similarity index 100% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/CMakeLists.txt rename to paddle/fluid/pir/dialect/kernel/CMakeLists.txt diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/CMakeLists.txt b/paddle/fluid/pir/dialect/kernel/ir/CMakeLists.txt similarity index 80% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/CMakeLists.txt rename to paddle/fluid/pir/dialect/kernel/ir/CMakeLists.txt index af5e5c4fc9016..bdfdb75410524 100644 --- a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/kernel/ir/CMakeLists.txt @@ -2,4 +2,4 @@ file(GLOB PADDLE_KERNEL_DIALECT_SRCS "*.cc") cc_library( pd_kernel_dialect SRCS ${PADDLE_KERNEL_DIALECT_SRCS} - DEPS pd_dialect_core) + DEPS pd_op_dialect_core) diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute_storage.h b/paddle/fluid/pir/dialect/kernel/ir/attribute_storage.h similarity index 88% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute_storage.h rename to paddle/fluid/pir/dialect/kernel/ir/attribute_storage.h index 18312b88b8ae2..1c8b4f9150b25 100644 --- 
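The matmul split above is ordinary SFINAE dispatch: the overload set is partitioned on the device context, so the GPU instantiation (compiled only when CUDA_VERSION >= 11060) forwards to phi::MatmulKernel while every other context keeps the legacy BLAS path. A self-contained toy showing the same mechanism (stand-in types; this is not Paddle code):

    // Sketch only: enable_if partitions ComputeMatmulImpl by context type.
    #include <iostream>
    #include <type_traits>

    struct GPUContext {};
    struct CPUContext {};

    template <typename DeviceContext>
    typename std::enable_if<std::is_same<DeviceContext, GPUContext>::value>::type
    ComputeMatmulImpl() {
      std::cout << "phi::MatmulKernel path (cuBLASLt-capable)\n";
    }

    template <typename DeviceContext>
    typename std::enable_if<!std::is_same<DeviceContext, GPUContext>::value>::type
    ComputeMatmulImpl() {
      std::cout << "legacy BLAS path\n";
    }

    int main() {
      ComputeMatmulImpl<GPUContext>();  // exactly one overload is viable
      ComputeMatmulImpl<CPUContext>();
    }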
a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute_storage.h +++ b/paddle/fluid/pir/dialect/kernel/ir/attribute_storage.h @@ -14,16 +14,16 @@ #pragma once -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/attribute_base.h" -#include "paddle/ir/core/utils.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_factory.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/attribute_base.h" +#include "paddle/pir/core/utils.h" namespace paddle { namespace dialect { -struct KernelAttributeStorage : public ir::AttributeStorage { +struct KernelAttributeStorage : public pir::AttributeStorage { using ParamKey = phi::KernelKey; explicit KernelAttributeStorage(const ParamKey &key) { kernel_key_ = key; } diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.cc similarity index 89% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.cc rename to paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.cc index 43ed52ffc6701..f8c23f993ca2d 100644 --- a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.cc @@ -12,6 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::KernelAttribute) diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h similarity index 86% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h rename to paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h index fa17b823f0278..7b6bc2336813a 100644 --- a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h @@ -14,14 +14,14 @@ #pragma once -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute_storage.h" -#include "paddle/ir/core/attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/attribute_storage.h" #include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/attribute.h" namespace paddle { namespace dialect { -class KernelAttribute : public ir::Attribute { +class KernelAttribute : public pir::Attribute { public: using Attribute::Attribute; diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc similarity index 76% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.cc rename to paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc index c2f4dfefb4d2b..592319dcfd36e 100644 --- a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc @@ -12,26 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/platform/init_phi.h" -#include "paddle/ir/core/ir_printer.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/ddim.h" +#include "paddle/pir/core/ir_printer.h" REGISTER_FILE_SYMBOLS(kernel_dialect); namespace paddle { namespace dialect { -PaddleKernelDialect::PaddleKernelDialect(ir::IrContext *context) - : ir::Dialect(name(), context, ir::TypeId::get()) { +KernelDialect::KernelDialect(pir::IrContext *context) + : pir::Dialect(name(), context, pir::TypeId::get()) { initialize(); } -void PaddleKernelDialect::initialize() { +void KernelDialect::initialize() { RegisterTypes(); RegisterTypes(); @@ -39,7 +39,7 @@ void PaddleKernelDialect::initialize() { RegisterAttributes(); } -void PaddleKernelDialect::PrintType(ir::Type type, std::ostream &os) const { +void KernelDialect::PrintType(pir::Type type, std::ostream &os) const { if (type.isa()) { AllocatedDenseTensorType tensor_type = type.dyn_cast(); @@ -67,16 +67,16 @@ void PaddleKernelDialect::PrintType(ir::Type type, std::ostream &os) const { } } -void PaddleKernelDialect::PrintAttribute(ir::Attribute attr, - std::ostream &os) const { +void KernelDialect::PrintAttribute(pir::Attribute attr, + std::ostream &os) const { phi::KernelKey kernel = attr.dyn_cast().data(); os << ""; } -void PaddleKernelDialect::PrintOperation(ir::Operation *op, - ir::IrPrinter &printer) const { +void KernelDialect::PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const { if (op->dyn_cast() || op->dyn_cast()) { auto &os = printer.os; printer.PrintOpResult(op); @@ -86,7 +86,7 @@ void PaddleKernelDialect::PrintOperation(ir::Operation *op, if (op->attributes().count("is_inplace") != 0 && op->attributes() .at("is_inplace") - .dyn_cast() + .dyn_cast() .data()) { kernel_name = kernel_name + "_"; } @@ -97,7 +97,7 @@ void PaddleKernelDialect::PrintOperation(ir::Operation *op, if (op->attributes().count("is_inplace") != 0 && op->attributes() .at("is_inplace") - .dyn_cast() + .dyn_cast() .data()) { kernel_name = kernel_name + "_"; } @@ -117,4 +117,4 @@ void PaddleKernelDialect::PrintOperation(ir::Operation *op, } // namespace dialect } // namespace paddle -IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PaddleKernelDialect) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::KernelDialect) diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h similarity index 64% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h rename to paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h index 8099e1d1da093..d2fbcadaf8cf2 100644 --- a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h @@ -14,23 +14,23 @@ #pragma once -#include "paddle/ir/core/dialect.h" +#include "paddle/pir/core/dialect.h" namespace paddle { namespace dialect { -class PaddleKernelDialect : public ir::Dialect { +class KernelDialect : public pir::Dialect { public: - 
explicit PaddleKernelDialect(ir::IrContext* context); + explicit KernelDialect(pir::IrContext* context); static const char* name() { return "pd_kernel"; } - void PrintType(ir::Type type, std::ostream& os) const override; + void PrintType(pir::Type type, std::ostream& os) const override; - void PrintAttribute(ir::Attribute attr, std::ostream& os) const override; + void PrintAttribute(pir::Attribute attr, std::ostream& os) const override; - void PrintOperation(ir::Operation* op, - ir::IrPrinter& printer) const override; // NOLINT + void PrintOperation(pir::Operation* op, + pir::IrPrinter& printer) const override; // NOLINT private: void initialize(); @@ -39,4 +39,4 @@ class PaddleKernelDialect : public ir::Dialect { } // namespace dialect } // namespace paddle -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PaddleKernelDialect) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::KernelDialect) diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc similarity index 78% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.cc rename to paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc index 4a934505aad55..62c1129f84620 100644 --- a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h" -#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/builtin_attribute.h" namespace paddle { namespace dialect { @@ -31,12 +31,12 @@ void PhiKernelOp::Verify() { auto& attributes = this->attributes(); PADDLE_ENFORCE(attributes.count("op_name") > 0 && - attributes.at("op_name").isa(), + attributes.at("op_name").isa(), phi::errors::PreconditionNotMet( "Type of attribute: op_name is not right.")); PADDLE_ENFORCE(attributes.count("kernel_name") > 0 && - attributes.at("kernel_name").isa(), + attributes.at("kernel_name").isa(), phi::errors::PreconditionNotMet( "Type of attribute: kernel_name is not right.")); @@ -47,10 +47,13 @@ void PhiKernelOp::Verify() { } std::string PhiKernelOp::op_name() { - return attributes().at("op_name").dyn_cast().AsString(); + return attributes().at("op_name").dyn_cast().AsString(); } std::string PhiKernelOp::kernel_name() { - return attributes().at("kernel_name").dyn_cast().AsString(); + return attributes() + .at("kernel_name") + .dyn_cast() + .AsString(); } phi::KernelKey PhiKernelOp::kernel_key() { return attributes().at("kernel_key").dyn_cast().data(); @@ -67,12 +70,12 @@ void LegacyKernelOp::Verify() { auto& attributes = this->attributes(); PADDLE_ENFORCE(attributes.count("op_name") > 0 && - attributes.at("op_name").isa(), + attributes.at("op_name").isa(), phi::errors::PreconditionNotMet( "Type of attribute: op_name is not right.")); PADDLE_ENFORCE(attributes.count("kernel_name") > 0 && - attributes.at("kernel_name").isa(), + attributes.at("kernel_name").isa(), phi::errors::PreconditionNotMet( "Type of attribute: kernel_name is not right.")); @@ -83,10 +86,13 @@ void LegacyKernelOp::Verify() { } std::string LegacyKernelOp::op_name() { - return attributes().at("op_name").dyn_cast().AsString(); + 
return attributes().at("op_name").dyn_cast().AsString(); } std::string LegacyKernelOp::kernel_name() { - return attributes().at("kernel_name").dyn_cast().AsString(); + return attributes() + .at("kernel_name") + .dyn_cast() + .AsString(); } phi::KernelKey LegacyKernelOp::kernel_key() { return attributes().at("kernel_key").dyn_cast().data(); diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h similarity index 89% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h rename to paddle/fluid/pir/dialect/kernel/ir/kernel_op.h index 0a574bc60b218..8a18959665e0c 100644 --- a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h @@ -14,13 +14,13 @@ #pragma once -#include "paddle/ir/core/builder.h" -#include "paddle/ir/core/op_base.h" #include "paddle/phi/core/kernel_factory.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/op_base.h" namespace paddle { namespace dialect { -class PhiKernelOp : public ir::Op { +class PhiKernelOp : public pir::Op { public: using Op::Op; static const char *name() { return "pd_kernel.phi_kernel"; } @@ -32,7 +32,7 @@ class PhiKernelOp : public ir::Op { void Verify(); }; -class LegacyKernelOp : public ir::Op { +class LegacyKernelOp : public pir::Op { public: using Op::Op; static const char *name() { return "pd_kernel.legacy_kernel"; } diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc similarity index 91% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.cc rename to paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc index 9740f1296a51b..60a722f13dab5 100644 --- a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
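Apart from the ir:: to pir:: namespace move, the attribute plumbing is unchanged: required attributes are checked with isa<> and unwrapped with dyn_cast<>().AsString(). The access pattern shared by the two accessors above, factored out as a hedged sketch (the helper and its name are illustrative, not part of the patch):

    // Sketch only: fetch a required string attribute from a pir operation,
    // mirroring PhiKernelOp::op_name()/kernel_name() above.
    std::string GetStrAttr(pir::Operation* op, const std::string& name) {
      const auto& attrs = op->attributes();
      PADDLE_ENFORCE(
          attrs.count(name) > 0 && attrs.at(name).isa<pir::StrAttribute>(),
          phi::errors::PreconditionNotMet(
              "Type of attribute: %s is not right.", name));
      return attrs.at(name).dyn_cast<pir::StrAttribute>().AsString();
    }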
-#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" namespace paddle { namespace dialect { @@ -21,7 +21,7 @@ const phi::Place& AllocatedDenseTensorType::place() const { return storage()->place_; } -const ir::Type& AllocatedDenseTensorType::dtype() const { +const pir::Type& AllocatedDenseTensorType::dtype() const { return storage()->dense_tensor_type_.dtype(); } @@ -45,7 +45,7 @@ const phi::Place& AllocatedSelectedRowsType::place() const { return storage()->place_; } -const ir::Type& AllocatedSelectedRowsType::dtype() const { +const pir::Type& AllocatedSelectedRowsType::dtype() const { return storage()->selected_rows_type_.dtype(); } diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h similarity index 66% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h rename to paddle/fluid/pir/dialect/kernel/ir/kernel_type.h index b00f2e5320dde..adb78639d65c0 100644 --- a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h @@ -14,30 +14,30 @@ #pragma once -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type_storage.h" -#include "paddle/ir/core/type.h" +#include "paddle/fluid/pir/dialect/kernel/ir/type_storage.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/core/type.h" namespace paddle { namespace dialect { -class AllocatedDenseTensorType : public ir::Type { +class AllocatedDenseTensorType + : public pir::Type::TypeBase { public: - using Type::Type; + using Base::Base; - DECLARE_TYPE_UTILITY_FUNCTOR(AllocatedDenseTensorType, - AllocatedDenseTensorTypeStorage); - - static AllocatedDenseTensorType get(ir::IrContext *ctx, + static AllocatedDenseTensorType get(pir::IrContext *ctx, const phi::Place &place, dialect::DenseTensorType type) { - return ir::TypeManager::template get( + return pir::TypeManager::template get( ctx, place, type); } - static AllocatedDenseTensorType get(ir::IrContext *ctx, + static AllocatedDenseTensorType get(pir::IrContext *ctx, const phi::Place &place, - const ir::Type &dtype, + const pir::Type &dtype, const phi::DDim &dims, const phi::DataLayout &layout, const phi::LoD &lod, @@ -45,13 +45,13 @@ class AllocatedDenseTensorType : public ir::Type { dialect::DenseTensorType dense_tensor_type = dialect::DenseTensorType::get(ctx, dtype, dims, layout, lod, offset); - return ir::TypeManager::template get( + return pir::TypeManager::template get( ctx, place, dense_tensor_type); } const phi::Place &place() const; - const ir::Type &dtype() const; + const pir::Type &dtype() const; const phi::DDim &dims() const; @@ -62,23 +62,23 @@ class AllocatedDenseTensorType : public ir::Type { const size_t &offset() const; }; -class AllocatedSelectedRowsType : public ir::Type { +class AllocatedSelectedRowsType + : public pir::Type::TypeBase { public: - using Type::Type; - - DECLARE_TYPE_UTILITY_FUNCTOR(AllocatedSelectedRowsType, - AllocatedSelectedRowsTypeStorage); + using Base::Base; - static AllocatedSelectedRowsType get(ir::IrContext *ctx, + static AllocatedSelectedRowsType get(pir::IrContext *ctx, const phi::Place &place, dialect::SelectedRowsType type) { - return ir::TypeManager::template get( + return pir::TypeManager::template get( ctx, place, type); } - static AllocatedSelectedRowsType get(ir::IrContext *ctx, + static 
AllocatedSelectedRowsType get(pir::IrContext *ctx, const phi::Place &place, - const ir::Type &dtype, + const pir::Type &dtype, const phi::DDim &dims, const phi::DataLayout &layout, const phi::LoD &lod, @@ -86,13 +86,13 @@ class AllocatedSelectedRowsType : public ir::Type { dialect::SelectedRowsType type = dialect::SelectedRowsType::get(ctx, dtype, dims, layout, lod, offset); - return ir::TypeManager::template get( + return pir::TypeManager::template get( ctx, place, type); } const phi::Place &place() const; - const ir::Type &dtype() const; + const pir::Type &dtype() const; const phi::DDim &dims() const; diff --git a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type_storage.h b/paddle/fluid/pir/dialect/kernel/ir/type_storage.h similarity index 72% rename from paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type_storage.h rename to paddle/fluid/pir/dialect/kernel/ir/type_storage.h index 1913dd6e6346c..46622587e51f5 100644 --- a/paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type_storage.h +++ b/paddle/fluid/pir/dialect/kernel/ir/type_storage.h @@ -16,10 +16,10 @@ #include -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/ir/core/type.h" -#include "paddle/ir/core/utils.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/utils.h" namespace paddle { namespace dialect { @@ -30,7 +30,7 @@ namespace dialect { /// following methods: (1)declare ParamKey, (2)define Construction method, /// (3)define HashValue method, (4)overload operator==. /// -struct AllocatedDenseTensorTypeStorage : public ir::TypeStorage { +struct AllocatedDenseTensorTypeStorage : public pir::TypeStorage { using Place = phi::Place; /// /// \brief Declare ParamKey according to parameter type. @@ -56,18 +56,19 @@ struct AllocatedDenseTensorTypeStorage : public ir::TypeStorage { static std::size_t HashValue(const ParamKey& key) { std::size_t hash_value = 0; // hash place - hash_value = ir::hash_combine(hash_value, std::get<0>(key).HashValue()); + hash_value = pir::hash_combine(hash_value, std::get<0>(key).HashValue()); // hash dtype auto dense_tensor_type = std::get<1>(key); - hash_value = ir::hash_combine(hash_value, - dialect::DenseTensorTypeStorage::HashValue( - dialect::DenseTensorTypeStorage::ParamKey( - dense_tensor_type.dtype(), - dense_tensor_type.dims(), - dense_tensor_type.data_layout(), - dense_tensor_type.lod(), - dense_tensor_type.offset()))); + hash_value = + pir::hash_combine(hash_value, + dialect::DenseTensorTypeStorage::HashValue( + dialect::DenseTensorTypeStorage::ParamKey( + dense_tensor_type.dtype(), + dense_tensor_type.dims(), + dense_tensor_type.data_layout(), + dense_tensor_type.lod(), + dense_tensor_type.offset()))); return hash_value; } @@ -92,7 +93,7 @@ struct AllocatedDenseTensorTypeStorage : public ir::TypeStorage { /// \brief Define Parametric TypeStorage for AllocatedSelectedRowsTypeStorage. /// /// -struct AllocatedSelectedRowsTypeStorage : public ir::TypeStorage { +struct AllocatedSelectedRowsTypeStorage : public pir::TypeStorage { using Place = phi::Place; /// /// \brief Declare ParamKey according to parameter type. 
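Both storages hash identically: fold the place's hash, then the hash of the wrapped tensor type's full parameter tuple, into a single seed via pir::hash_combine. The folding step is the usual seed-mixing idiom; a self-contained equivalent for reference (the exact mixing constants inside pir may differ, so treat this as an assumption):

    // Sketch only: hash_combine-style folding as used by HashValue above.
    #include <cstddef>

    inline std::size_t hash_combine(std::size_t seed, std::size_t value) {
      return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

    std::size_t HashKey(std::size_t place_hash, std::size_t type_hash) {
      std::size_t h = 0;
      h = hash_combine(h, place_hash);  // fold the place first
      h = hash_combine(h, type_hash);   // then the DenseTensorType params
      return h;
    }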
@@ -118,18 +119,19 @@ struct AllocatedSelectedRowsTypeStorage : public ir::TypeStorage { static std::size_t HashValue(const ParamKey& key) { std::size_t hash_value = 791; // hash place - hash_value = ir::hash_combine(hash_value, std::get<0>(key).HashValue()); + hash_value = pir::hash_combine(hash_value, std::get<0>(key).HashValue()); // hash dtype auto selected_rows_type = std::get<1>(key); - hash_value = ir::hash_combine(hash_value, - dialect::DenseTensorTypeStorage::HashValue( - dialect::DenseTensorTypeStorage::ParamKey( - selected_rows_type.dtype(), - selected_rows_type.dims(), - selected_rows_type.data_layout(), - selected_rows_type.lod(), - selected_rows_type.offset()))); + hash_value = + pir::hash_combine(hash_value, + dialect::DenseTensorTypeStorage::HashValue( + dialect::DenseTensorTypeStorage::ParamKey( + selected_rows_type.dtype(), + selected_rows_type.dims(), + selected_rows_type.data_layout(), + selected_rows_type.lod(), + selected_rows_type.offset()))); return hash_value; } diff --git a/paddle/fluid/ir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py similarity index 71% rename from paddle/fluid/ir/dialect/op_generator/api_gen.py rename to paddle/fluid/pir/dialect/op_generator/api_gen.py index cae035c657b69..66f1af1ed69e7 100644 --- a/paddle/fluid/ir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py @@ -17,7 +17,12 @@ import re import yaml -from op_gen import OpCompatParser, OpInfoParser, to_pascal_case +from op_gen import ( + PD_MANUAL_OP_LIST, + OpCompatParser, + OpInfoParser, + to_pascal_case, +) H_FILE_TEMPLATE = """ @@ -25,11 +30,11 @@ #include -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/value.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/phi/common/scalar.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_api.h" {body} @@ -37,11 +42,11 @@ CPP_FILE_TEMPLATE = """ -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" -#include "paddle/ir/core/builder.h" -#include "paddle/ir/core/builtin_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_op.h" {body} @@ -71,17 +76,16 @@ """ COMBINE_OP_TEMPLATE = """ - auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" + auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" SPLIT_OP_TEMPLATE = """ - auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" + auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" COMPUTE_OP_TEMPLATE = """ paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build({args});""" -OP_RESULT = 'ir::OpResult' -VECTOR_TYPE = 'ir::VectorType' -PD_MANUAL_OP_LIST = ['add_n'] +OP_RESULT = 'pir::OpResult' +VECTOR_TYPE = 'pir::VectorType' def get_op_class_name(op_name): @@ -91,9 +95,9 @@ def get_op_class_name(op_name): class CodeGen: def __init__(self) -> None: self._type_map = { - 'paddle::dialect::DenseTensorType': 'ir::OpResult', - 'paddle::dialect::SelectedRowsType': 'ir::OpResult', - 'ir::VectorType': 'std::vector', + 
'paddle::dialect::DenseTensorType': 'pir::OpResult', + 'paddle::dialect::SelectedRowsType': 'pir::OpResult', + 'pir::VectorType': 'std::vector', } def _parse_yaml(self, op_yaml_files, op_compat_yaml_file): @@ -111,6 +115,11 @@ def _parse_yaml(self, op_yaml_files, op_compat_yaml_file): ) return op_info_items + def _need_skip(self, op_info, op_name): + return ( + op_info.infer_meta_func is None and op_name not in PD_MANUAL_OP_LIST + ) + # ===================================== # Gen declare functions # ===================================== @@ -123,11 +132,14 @@ def _gen_api_inputs(self, op_info): ret.append(f'{self._type_map[type]} {name}') return ', '.join(ret) - def _gen_api_attrs(self, op_info, with_default, is_mutable_attr): + def _gen_api_attrs( + self, op_info, with_default, is_mutable_attr, is_vector_mutable_sttr + ): name_list = op_info.attribute_name_list type_list = op_info.attribute_build_arg_type_list default_value_list = op_info.attribute_default_value_list mutable_name_list = op_info.mutable_attribute_name_list + mutable_type_list = op_info.mutable_attribute_type_list assert len(name_list) == len(type_list) == len(default_value_list) no_mutable_attr = [] mutable_attr = [] @@ -135,7 +147,14 @@ def _gen_api_attrs(self, op_info, with_default, is_mutable_attr): name_list, type_list, default_value_list ): if is_mutable_attr and name in mutable_name_list: - mutable_attr.append(f'{OP_RESULT} {name}') + if ( + mutable_type_list[mutable_name_list.index(name)][0] + == "paddle::dialect::IntArrayAttribute" + and is_vector_mutable_sttr + ): + mutable_attr.append(f'std::vector<{OP_RESULT}> {name}') + else: + mutable_attr.append(f'{OP_RESULT} {name}') continue if with_default and default_value is not None: if type in ['float', 'double']: @@ -149,9 +168,17 @@ def _gen_api_attrs(self, op_info, with_default, is_mutable_attr): no_mutable_attr.append(f'{type} {name}') return ', '.join(mutable_attr + no_mutable_attr) - def _gen_api_args(self, op_info, with_default_attr, is_mutable_attr): + def _gen_api_args( + self, + op_info, + with_default_attr, + is_mutable_attr, + is_vector_mutable_attr, + ): inputs = self._gen_api_inputs(op_info) - attrs = self._gen_api_attrs(op_info, with_default_attr, is_mutable_attr) + attrs = self._gen_api_attrs( + op_info, with_default_attr, is_mutable_attr, is_vector_mutable_attr + ) return (inputs + ', ' + attrs).strip(', ') def _gen_ret_type(self, op_info): @@ -178,11 +205,15 @@ def _gen_ret_type(self, op_info): elif output_num == 0: return 'void' - def _gen_one_declare(self, op_info, op_name, is_mutable_attr): + def _gen_one_declare( + self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr + ): return API_DECLARE_TEMPLATE.format( ret_type=self._gen_ret_type(op_info), api_name=op_name, - args=self._gen_api_args(op_info, True, is_mutable_attr), + args=self._gen_api_args( + op_info, True, is_mutable_attr, is_vector_mutable_attr + ), ) def _gen_h_file(self, op_info_items, namespaces, h_file_path): @@ -191,15 +222,21 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path): for op_name in op_info.op_phi_name: # NOTE:When infer_meta_func is None, the Build() function generated in pd_op # is wrong, so temporarily skip the automatic generation of these APIs - if ( - op_info.infer_meta_func is None - and op_name not in PD_MANUAL_OP_LIST - ): + if self._need_skip(op_info, op_name): continue - declare_str += self._gen_one_declare(op_info, op_name, False) + declare_str += self._gen_one_declare( + op_info, op_name, False, False + ) if 
len(op_info.mutable_attribute_name_list) > 0: - declare_str += self._gen_one_declare(op_info, op_name, True) - + declare_str += self._gen_one_declare( + op_info, op_name, True, False + ) + if "paddle::dialect::IntArrayAttribute" in { + type[0] for type in op_info.mutable_attribute_type_list + }: + declare_str += self._gen_one_declare( + op_info, op_name, True, True + ) body = declare_str for namespace in reversed(namespaces): body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body) @@ -209,7 +246,7 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path): # ===================================== # Gen impl functions # ===================================== - def _gen_in_combine(self, op_info): + def _gen_in_combine(self, op_info, is_mutable_attr, is_vector_mutable_attr): name_list = op_info.input_name_list type_list = op_info.input_type_list assert len(name_list) == len(type_list) @@ -224,6 +261,24 @@ def _gen_in_combine(self, op_info): combine_op_list.append(op_name) else: combine_op_list.append(None) + + if is_mutable_attr: + name_list = op_info.mutable_attribute_name_list + type_list = op_info.mutable_attribute_type_list + assert len(name_list) == len(type_list) + for name, type in zip(name_list, type_list): + if ( + type[0] == "paddle::dialect::IntArrayAttribute" + and is_vector_mutable_attr + ): + op_name = f'{name}_combine_op' + combine_op += COMBINE_OP_TEMPLATE.format( + op_name=op_name, in_name=name + ) + combine_op_list.append(op_name) + else: + combine_op_list.append(None) + return combine_op, combine_op_list def _gen_compute_op_args( @@ -233,15 +288,22 @@ def _gen_compute_op_args( all_attr_list = op_info.attribute_name_list no_mutable_attr_list = op_info.non_mutable_attribute_name_list mutable_attr_list = op_info.mutable_attribute_name_list - assert len(input_name_list) == len(in_combine_op_list) + assert len(input_name_list) + len(mutable_attr_list) == len( + in_combine_op_list + ) or len(input_name_list) == len(in_combine_op_list) ret = [] - for input_name, combine_op in zip(input_name_list, in_combine_op_list): + if is_mutable_attr: + name_list = input_name_list + mutable_attr_list + else: + name_list = input_name_list + + for input_name, combine_op in zip(name_list, in_combine_op_list): if combine_op is None: ret.append(input_name) else: ret.append(f'{combine_op}.out()') if is_mutable_attr: - ret += list(mutable_attr_list + no_mutable_attr_list) + ret += list(no_mutable_attr_list) else: ret += list(all_attr_list) return ', '.join(ret) @@ -293,9 +355,13 @@ def _gen_return_result(self, ret_list): elif len(ret_list) == 0: return 'return;' - def _gen_one_impl(self, op_info, op_name, is_mutable_attr): + def _gen_one_impl( + self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr + ): ret_type = self._gen_ret_type(op_info) - in_combine, in_combine_op_list = self._gen_in_combine(op_info) + in_combine, in_combine_op_list = self._gen_in_combine( + op_info, is_mutable_attr, is_vector_mutable_attr + ) compute_op, op_inst_name = self._gen_compute_op( op_info, op_name, in_combine_op_list, is_mutable_attr ) @@ -309,7 +375,9 @@ def _gen_one_impl(self, op_info, op_name, is_mutable_attr): ret = API_IMPL_TEMPLATE.format( ret_type=ret_type, api_name=op_name, - args=self._gen_api_args(op_info, False, is_mutable_attr), + args=self._gen_api_args( + op_info, False, is_mutable_attr, is_vector_mutable_attr + ), in_combine=in_combine, compute_op=compute_op, out_split=out_split, @@ -325,14 +393,19 @@ def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path): for op_name 
in op_info.op_phi_name: # NOTE:When infer_meta_func is None, the Build() function generated in pd_op # is wrong, so temporarily skip the automatic generation of these APIs - if ( - op_info.infer_meta_func is None - and op_name not in PD_MANUAL_OP_LIST - ): + if self._need_skip(op_info, op_name): continue - impl_str += self._gen_one_impl(op_info, op_name, False) + impl_str += self._gen_one_impl(op_info, op_name, False, False) if len(op_info.mutable_attribute_name_list) > 0: - impl_str += self._gen_one_impl(op_info, op_name, True) + impl_str += self._gen_one_impl( + op_info, op_name, True, False + ) + if "paddle::dialect::IntArrayAttribute" in { + type[0] for type in op_info.mutable_attribute_type_list + }: + impl_str += self._gen_one_impl( + op_info, op_name, True, True + ) body = impl_str for namespace in reversed(namespaces): body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body) diff --git a/paddle/fluid/ir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py similarity index 76% rename from paddle/fluid/ir/dialect/op_generator/op_build_gen.py rename to paddle/fluid/pir/dialect/op_generator/op_build_gen.py index 66d1094c9e5fc..66a3d5fbdf311 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -13,9 +13,22 @@ # limitations under the License. # generator build function -_INFERMETA_NEED_META_CONFIG = {'SplitInferMeta'} +_INFERMETA_NEED_META_CONFIG = { + 'SplitInferMeta', + 'SumInferMeta', + 'SplitWithNumInferMeta', + 'ConcatInferMeta', + 'ReduceIntArrayAxisInferMeta', +} + +_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = { + 'SplitOp', + 'SumOp', + 'SplitWithNumOp', + 'ConcatOp', + 'MeanOp', +} -_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {'SplitOp'} OP_BUILD_TEMPLATE = """ void {op_name}::Build({build_args}) {{ @@ -42,16 +55,16 @@ def GenBuildInputArgsStr( attr_args_is_map=False, ): ''' - Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={} + Example: pir::Builder &builder, pir::OperationArgument &argument, pir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={} ''' # add inputs - build_args_str = "ir::Builder &builder, ir::OperationArgument &argument" + build_args_str = "pir::Builder &builder, pir::OperationArgument &argument" if len(op_input_name_list) > 0: for input_name in op_input_name_list: - build_args_str += ", ir::OpResult " + input_name + "_" + build_args_str += ", pir::OpResult " + input_name + "_" if attr_args_is_map: - build_args_str += ", ir::AttributeMap attributes" + build_args_str += ", pir::AttributeMap attributes" else: if not mutable_attr_is_input: # add attributes @@ -86,7 +99,7 @@ def GenBuildInputArgsStr( # add mutable attributes as inputs if len(op_mutable_attribute_name_list) > 0: for mutable_attr in op_mutable_attribute_name_list: - build_args_str += ", ir::OpResult " + mutable_attr + "_" + build_args_str += ", pir::OpResult " + mutable_attr + "_" # add non-mutable attributes for attr_idx in range(len(op_non_mutable_attribute_name_list)): @@ -146,11 +159,11 @@ def GenBuildInserFullForMutableAttribute( build_mutable_attribute = "" BUILD_INTARRAY_ATTRIBUTE_TEMPLATE = """ // Generate int_array mutable attribute: {attr_name} paddle::dialect::FullIntArrayOp full_{attr_name}_op = builder.Build({attr_name}, {phi_dtype}, phi::CPUPlace()); - ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0); + pir::OpResult {attr_name}_ = 
full_{attr_name}_op->result(0); """ BUILD_SCALAR_ATTRIBUTE_TEMPLATE = """ // Generate scalar mutable attribute: {attr_name} paddle::dialect::FullOp full_{attr_name}_op = builder.Build(std::vector{{1}}, {attr_name}, {phi_dtype}, phi::CPUPlace()); - ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0); + pir::OpResult {attr_name}_ = full_{attr_name}_op->result(0); """ for idx in range(len(op_mutable_attribute_name_list)): attr_name = op_mutable_attribute_name_list[idx] @@ -177,7 +190,7 @@ def GenBuildInserFullForMutableAttribute( def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list): - BUILD_INPUT_TEMPLATE = """ std::vector argument_inputs = {{{inputs_args}}}; + BUILD_INPUT_TEMPLATE = """ std::vector argument_inputs = {{{inputs_args}}}; argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); """ build_input_str = ' VLOG(4) << "Builder construction inputs";\n' @@ -194,24 +207,25 @@ def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list): def GenBuildAttributes( op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list ): - INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr})); + INTARRAY_STR_TEMPLATE = """ pir::Attribute attr_{attr_name} = {op_attribute_type}::get(pir::IrContext::Instance(), phi::IntArray({attr})); """ - SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = paddle::dialect::TransToIrAttribute({attr}, ir::IrContext::Instance()); + SCALAR_STR_TEMPLATE = """ pir::Attribute attr_{attr_name} = paddle::dialect::TransToIrAttribute({attr}, pir::IrContext::Instance()); """ - STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), {attr}); + STR_TEMPLATE = """ pir::Attribute attr_{attr_name} = {op_attribute_type}::get(pir::IrContext::Instance(), {attr}); """ - ARRAY_ATTRIBUTE_TEMPLATE = """ std::vector vec_{attr_name}; + ARRAY_ATTRIBUTE_TEMPLATE = """ std::vector vec_{attr_name}; for (size_t i = 0; i < static_cast({attr_size}); i++) {{ {create_attribute} vec_{attr_name}.push_back(attr_{attr_name}); }} - ir::Attribute attr_{attr_name} = ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_{attr_name}); + pir::Attribute attr_{attr_name} = pir::ArrayAttribute::get(pir::IrContext::Instance(), vec_{attr_name}); """ attr_str = ' VLOG(4) << "Builder construction attributes";\n' + array_attr_type = "pir::ArrayAttribute<" for idx in range(len(op_non_mutable_attribute_name_list)): - if "ir::ArrayAttribute<" in op_non_mutable_attribute_type_list[idx]: + if array_attr_type in op_non_mutable_attribute_type_list[idx]: inner_attribute_type = op_non_mutable_attribute_type_list[idx][ - 19:-1 + len(array_attr_type) : -1 ] if inner_attribute_type == "paddle::dialect::IntArrayAttribute": attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( @@ -280,12 +294,15 @@ def GenBuildOutputs( op_class_name, op_input_name_list, op_input_type_list, + op_input_optional_list, op_mutable_attribute_name_list, op_mutable_attribute_type_list, op_output_name_list, op_output_type_list, op_output_size_list, + op_output_optional_list, op_infer_meta_map, + op_inplace_map, mutable_attr_is_input=False, ): build_output_str = ' VLOG(4) << "Builder construction outputs";\n' @@ -299,6 +316,23 @@ def GenBuildOutputs( VLOG(4) << "Builder construction meta_{name}"; phi::MetaTensor meta_{name}(&ir_meta_tensor_{name}); """ + + CREATE_OPTIONAL_INPUT_METATENSOR_TEMPLATE = """ + phi::MetaTensor meta_{name}; + if ({name}_.impl() != nullptr) {{ + 
paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast(); + VLOG(4) << "Builder construction dense_{name}"; + paddle::dialect::IrMetaTensor ir_meta_tensor_{name}(paddle::dialect::TransToPhiDataType({name}.dtype()), + {name}.dims(), + {name}.data_layout(), + {name}.lod(), + {name}.offset()); + VLOG(4) << "Builder construction meta_{name}"; + meta_{name} = phi::MetaTensor(&ir_meta_tensor_{name}); + }} + +""" + CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_ir_meta_tensor_{name}; for (size_t i=0; i < static_cast({name}.size()); i++) {{ vec_ir_meta_tensor_{name}.push_back(paddle::dialect::IrMetaTensor(paddle::dialect::TransToPhiDataType({name}[i].dyn_cast().dtype()), @@ -322,7 +356,7 @@ def GenBuildOutputs( CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().to<{dtype}>(); (void){name};\n""" CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name}; - if ({name}_.owner()->info().id() == ir::TypeId::get()) {{ + if ({name}_.owner()->info().id() == pir::TypeId::get()) {{ {name} = std::move(phi::IntArray({name}_.owner() ->dyn_cast() .attributes() @@ -330,8 +364,8 @@ def GenBuildOutputs( .dyn_cast() .data() .GetData())); - }} else if ({name}_.type().isa()) {{ - size_t {name}_size = {name}_.type().dyn_cast().size(); + }} else if ({name}_.type().isa()) {{ + size_t {name}_size = {name}_.type().dyn_cast().size(); {name} = std::move(phi::IntArray(std::vector({name}_size, -1))); {name}.SetFromTensor(true); }} else if ({name}_.type().isa()) {{ @@ -343,7 +377,7 @@ def GenBuildOutputs( }}\n""" CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name}; - if ({name}_.owner()->info().id() == ir::TypeId::get()) {{ + if ({name}_.owner()->info().id() == pir::TypeId::get()) {{ {name} = std::move(phi::Scalar({name}_.owner() ->dyn_cast() .attributes() @@ -373,15 +407,16 @@ def GenBuildOutputs( # Prepar input type for idx in range(len(op_input_name_list)): # is a vector - if 'ir::VectorType' in op_input_type_list[idx]: - build_output_str += " ir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( + if 'pir::VectorType' in op_input_type_list[idx]: + build_output_str += " pir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( name=op_input_name_list[idx] ) # is a Tensor else: - build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( - name=op_input_name_list[idx] - ) + if op_input_optional_list[idx] == 'false': + build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( + name=op_input_name_list[idx] + ) # Prepare mutable attributes if mutable_attr_is_input: @@ -414,7 +449,7 @@ def GenBuildOutputs( ) ) # string - elif attr_dtype[0] == "ir::StrAttribute": + elif attr_dtype[0] == "pir::StrAttribute": build_output_str += "" else: assert "mutable attribtue type is not right." 
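The templates above let the generated Build() cope with mutable attributes whose values are not compile-time constants (the ops listed in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE). A rough Python rendering of the dispatch the generated C++ performs for a mutable IntArray attribute follows; owner(), attribute(), and the type predicates are illustrative stand-ins for the corresponding pir calls, not generator code:

def resolve_int_array(value):
    # Constant case: the literal lives on the producing full_int_array op.
    producer = value.owner()
    if producer.name() == "pd_op.full_int_array":
        return {"data": producer.attribute("value"), "from_tensor": False}
    # Combined 1-D tensors: only the element count is known statically, so
    # emit -1 placeholders and defer the real values to execution time
    # (this mirrors SetFromTensor(true) in the template above).
    if value.type().is_vector():
        return {"data": [-1] * value.type().size(), "from_tensor": True}
    # A single dense-tensor input is handled by a further branch in the
    # generated code; this sketch stops here.
    raise NotImplementedError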
@@ -430,7 +465,7 @@ def GenBuildOutputs( ) not in infer_meta_args: # is a vector if ( - 'ir::VectorType' + 'pir::VectorType' in op_input_type_list[ op_input_name_list.index( op_infer_meta_map['param'][idx] @@ -444,9 +479,21 @@ def GenBuildOutputs( ) # is a Tensor else: - build_output_str += CREATE_INPUT_METATENSOR_TEMPLATE.format( - name=op_infer_meta_map['param'][idx] + input_index = op_input_name_list.index( + op_infer_meta_map['param'][idx] ) + if op_input_optional_list[input_index] == 'true': + build_output_str += ( + CREATE_OPTIONAL_INPUT_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx] + ) + ) + else: + build_output_str += ( + CREATE_INPUT_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx] + ) + ) infer_meta_args.append("meta_" + op_infer_meta_map['param'][idx]) # is attribute @@ -456,7 +503,7 @@ def GenBuildOutputs( # Prepare outputs_meta_tensor for infer meta for idx in range(len(op_output_name_list)): # is a vector - if 'ir::VectorType' in op_output_type_list[idx]: + if 'pir::VectorType' in op_output_type_list[idx]: build_output_str += CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE.format( name=op_output_name_list[idx], output_size=op_output_size_list[idx], @@ -488,32 +535,58 @@ def GenBuildOutputs( ) # use dense_{name} or vec_dense_{name} to create Outputs type - build_output_str += "\n std::vector argument_outputs;" + build_output_str += "\n std::vector argument_outputs;" CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE = """ - ir::Type {name}_dense_tensor_type = paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset()); + pir::Type {name}_dense_tensor_type = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset()); argument_outputs.push_back({name}_dense_tensor_type); """ + + CREATE_OUTPUT_INPLACE_OPTIONAL_DENSE_TENSOR_TEMPLATE = """ + if ({input_name}_.impl() != nullptr) {{ + pir::Type {output_name}_dense_tensor_type = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_{output_name}.dtype()), dense_{output_name}.dims(), dense_{output_name}.layout(), dense_{output_name}.lod(), dense_{output_name}.offset()); + argument_outputs.push_back({output_name}_dense_tensor_type); + }} else {{ + pir::Type {output_name}_type; + argument_outputs.push_back({output_name}_type); + }} + +""" + CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE = """ - std::vector {name}_types; + std::vector {name}_types; for (size_t i=0; i < static_cast({output_size}); i++) {{ - {name}_types.push_back(paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), paddle::dialect::TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset())); + {name}_types.push_back(paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset())); }} - ir::Type {name}_vector_type = ir::VectorType::get(ir::IrContext::Instance(), {name}_types); + pir::Type {name}_vector_type = pir::VectorType::get(pir::IrContext::Instance(), {name}_types); argument_outputs.push_back({name}_vector_type); """ for 
idx in range(len(op_output_name_list)): # is a vector - if 'ir::VectorType' in op_output_type_list[idx]: + if 'pir::VectorType' in op_output_type_list[idx]: build_output_str += CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE.format( name=op_output_name_list[idx], output_size=op_output_size_list[idx], ) # is a Tensor else: - build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format( - name=op_output_name_list[idx] + output_name = op_output_name_list[idx] + has_input_inplace = ( + op_inplace_map is not None + and output_name in op_inplace_map.keys() ) + if op_output_optional_list[idx] == 'true' and has_input_inplace: + # is a inplace optional output + build_output_str += ( + CREATE_OUTPUT_INPLACE_OPTIONAL_DENSE_TENSOR_TEMPLATE.format( + input_name=op_inplace_map[output_name], + output_name=output_name, + ) + ) + else: + build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format( + name=output_name + ) build_output_str += " argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());\n" @@ -524,6 +597,7 @@ def gen_build_func_str( op_class_name, op_input_name_list, op_input_type_list, + op_input_optional_list, op_attribute_name_list, op_attribute_type_list, op_attribute_build_arg_type_list, @@ -537,7 +611,9 @@ def gen_build_func_str( op_output_name_list, op_output_type_list, op_output_size_list, + op_output_optional_list, op_infer_meta_map, + op_inplace_map, muta_attr_is_input=False, attr_args_is_map=False, ): @@ -593,35 +669,59 @@ def gen_build_func_str( op_class_name, op_input_name_list, op_input_type_list, + op_input_optional_list, op_mutable_attribute_name_list, op_mutable_attribute_type_list, op_output_name_list, op_output_type_list, op_output_size_list, + op_output_optional_list, op_infer_meta_map, + op_inplace_map, muta_attr_is_input, ) GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """ + PADDLE_ENFORCE( + attributes.find("{attribute_name}") != attributes.end(), + phi::errors::NotFound( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data(); """ GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """ - {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().AsString(); + PADDLE_ENFORCE( + attributes.find("{attribute_name}") != attributes.end(), + phi::errors::NotFound( + "'{attribute_name}' Attribute is expected for {op_name}. ")); + {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().AsString(); """ GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ + PADDLE_ENFORCE( + attributes.find("{attribute_name}") != attributes.end(), + phi::errors::NotFound( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name}; - for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ - {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast().at(i).dyn_cast<{inner_type}>().{data_name}()); + for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ + {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast().at(i).dyn_cast<{inner_type}>().{data_name}()); }} """ GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ + PADDLE_ENFORCE( + attributes.find("{attribute_name}") != attributes.end(), + phi::errors::NotFound( + "'{attribute_name}' Attribute is expected for {op_name}. 
")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().GetData(); """ GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE = """ + PADDLE_ENFORCE( + attributes.find("{attribute_name}") != attributes.end(), + phi::errors::NotFound( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().to<{attr_type}>(); """ get_attributes_str = "" + array_attr_str = "pir::ArrayAttribute" if attr_args_is_map: for idx in range(len(op_attribute_name_list)): attr_type = op_attribute_build_arg_type_list[idx] @@ -629,13 +729,17 @@ def gen_build_func_str( attr_type = attr_type.replace("&", "") # if op_attribute_build_arg_type_list[idx] == "const std::vector&": # attr_type = "std::vector" - if "ir::ArrayAttribute" in op_attribute_type_list[idx]: - inner_type = op_attribute_type_list[idx][19:-1] + + if array_attr_str in op_attribute_type_list[idx]: + inner_type = op_attribute_type_list[idx][ + len(array_attr_str) + 1 : -1 + ] data_name = "data" - if inner_type == "ir::StrAttribute": + if inner_type == "pir::StrAttribute": data_name = "AsString" get_attributes_str += ( GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( + op_name=op_class_name, attr_type=attr_type, attribute_name=op_attribute_name_list[idx], inner_type=inner_type, @@ -648,6 +752,7 @@ def gen_build_func_str( ): get_attributes_str += ( GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( + op_name=op_class_name, attr_type=attr_type, attribute_name=op_attribute_name_list[idx], ) @@ -658,13 +763,15 @@ def gen_build_func_str( ): get_attributes_str += ( GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE.format( + op_name=op_class_name, attr_type=attr_type, attribute_name=op_attribute_name_list[idx], ) ) - elif "ir::StrAttribute" in op_attribute_type_list[idx]: + elif "pir::StrAttribute" in op_attribute_type_list[idx]: get_attributes_str += ( GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE.format( + op_name=op_class_name, attr_type=attr_type, attribute_name=op_attribute_name_list[idx], attr_ir_type=op_attribute_type_list[idx], @@ -672,6 +779,7 @@ def gen_build_func_str( ) else: get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format( + op_name=op_class_name, attr_type=attr_type, attribute_name=op_attribute_name_list[idx], attr_ir_type=op_attribute_type_list[idx], diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py similarity index 92% rename from paddle/fluid/ir/dialect/op_generator/op_gen.py rename to paddle/fluid/pir/dialect/op_generator/op_gen.py index 8663d23059d45..d52bf901ae17f 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -15,6 +15,8 @@ import argparse import logging import os +import pathlib +import sys import yaml from op_build_gen import gen_build_func_str @@ -30,6 +32,12 @@ vjp_interface_implementation_gen_op_list, ) +# import from paddle/fluid/primitive/code_gen/gen.py +sys.path.append( + str(pathlib.Path(__file__).resolve().parents[3] / 'primitive/codegen') +) +import gen as vjp_gen + # ===================================== # String Template for h file code gen # ===================================== @@ -41,22 +49,23 @@ #undef GET_OP_LIST {op_declare} #else -// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py" +// This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py" #include -#include "paddle/ir/core/builder.h" -#include "paddle/ir/core/operation_utils.h" -#include 
"paddle/ir/core/op_base.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/vjp.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/trait/inplace.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/operation_utils.h" +#include "paddle/pir/core/op_base.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/vjp.h" +#include "paddle/fluid/pir/dialect/operator/trait/inplace.h" +#include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/phi/core/infermeta_utils.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" {input} @@ -72,7 +81,7 @@ """ OP_DECLARE_TEMPLATE = """ -class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ +class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ public: using Op::Op; static const char *name() {{ return "{dialect_op_name}"; }} @@ -97,15 +106,15 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ # ===================================== # String Template for cc file code gen # ===================================== -CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py" +CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py" #include "{h_file}" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/builtin_type.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/ir_context.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/meta_tensor.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/ir_context.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/infermeta/binary.h" @@ -117,7 +126,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ #include "paddle/phi/infermeta/fusion.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" -#include "paddle/ir/core/op_base.h" +#include "paddle/pir/core/op_base.h" {input} @@ -126,13 +135,13 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ # ===================================== # String Template for pd_op_vjp.cc file code gen # ===================================== -VJP_CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include 
"paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" +VJP_CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/primitive/rule/vjp/vjp.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/op_base.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/op_base.h" #include "paddle/phi/common/int_array.h" namespace paddle {{ @@ -166,14 +175,14 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ """ scalar_type_maps = { - 'int': 'ir::Int32Attribute', - 'int64_t': 'ir::Int64Attribute', - 'float': 'ir::FloatAttribute', - 'dobule': 'ir::DoubleAttribute', - 'bool': 'ir::BoolAttribute', + 'int': 'pir::Int32Attribute', + 'int64_t': 'pir::Int64Attribute', + 'float': 'pir::FloatAttribute', + 'dobule': 'pir::DoubleAttribute', + 'bool': 'pir::BoolAttribute', } -_NO_NEED_GEN_OPS = {'add_n', 'add_n_', 'add_n_with_kernel', 'split_grad'} +PD_MANUAL_OP_LIST = {'add_n', 'add_n_', 'add_n_with_kernel', 'split_grad'} def to_phi_and_fluid_op_name(op_item): @@ -255,33 +264,33 @@ def __init__(self, op_yaml_item, op_compat_item): self.attr_types_map = { 'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'], 'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'], - 'Scalar(int)': ['ir::Int32Attribute', 'int'], - 'Scalar(int64_t)': ['ir::Int64Attribute', 'int64_t'], - 'Scalar(float)': ['ir::FloatAttribute', 'float'], - 'Scalar(dobule)': ['ir::DoubleAttribute', 'dobule'], + 'Scalar(int)': ['pir::Int32Attribute', 'int'], + 'Scalar(int64_t)': ['pir::Int64Attribute', 'int64_t'], + 'Scalar(float)': ['pir::FloatAttribute', 'float'], + 'Scalar(dobule)': ['pir::DoubleAttribute', 'dobule'], 'Scalar[]': [ - 'ir::ArrayAttribute', + 'pir::ArrayAttribute', 'const std::vector&', ], - 'int': ['ir::Int32Attribute', 'int'], - 'int32_t': ['ir::Int32Attribute', 'int32_t'], - 'int64_t': ['ir::Int64Attribute', 'int64_t'], - 'long': ['ir::LongAttribute', 'long'], - 'size_t': ['ir::Size_tAttribute', 'size_t'], - 'float': ['ir::FloatAttribute', 'float'], + 'int': ['pir::Int32Attribute', 'int'], + 'int32_t': ['pir::Int32Attribute', 'int32_t'], + 'int64_t': ['pir::Int64Attribute', 'int64_t'], + 'long': ['pir::LongAttribute', 'long'], + 'size_t': ['pir::Size_tAttribute', 'size_t'], + 'float': ['pir::FloatAttribute', 'float'], 'float[]': [ - 'ir::ArrayAttribute', + 'pir::ArrayAttribute', 'const std::vector&', ], - 'double': ['ir::DoubleAttribute', 'double'], - 'bool': ['ir::BoolAttribute', 'bool'], + 'double': ['pir::DoubleAttribute', 'double'], + 'bool': ['pir::BoolAttribute', 'bool'], 'bool[]': [ - 'ir::ArrayAttribute', + 'pir::ArrayAttribute', 'const std::vector&', ], - 'str': ['ir::StrAttribute', 'const std::string&'], + 'str': ['pir::StrAttribute', 'const std::string&'], 'str[]': [ - 'ir::ArrayAttribute', + 'pir::ArrayAttribute', 'const std::vector&', ], 'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'], @@ -291,11 +300,11 @@ def __init__(self, op_yaml_item, op_compat_item): ], 'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'], 'int64_t[]': [ - 'ir::ArrayAttribute', + 'pir::ArrayAttribute', 'const std::vector&', ], 'int[]': [ - 'ir::ArrayAttribute', + 'pir::ArrayAttribute', 'const std::vector&', ], } @@ -517,7 +526,7 @@ def parse_input_name_list(self): def parse_input_type_list(self): input_types_map = { 'Tensor': 
'paddle::dialect::DenseTensorType', - 'Tensor[]': 'ir::VectorType', + 'Tensor[]': 'pir::VectorType', } type_list = [] for input_info in self.op_yaml_item['inputs']: @@ -554,7 +563,7 @@ def parse_output_name_list(self): def parse_output_type_list(self): output_type_map = { 'Tensor': 'paddle::dialect::DenseTensorType', - 'Tensor[]': 'ir::VectorType', + 'Tensor[]': 'pir::VectorType', 'SelectedRows': 'paddle::dialect::SelectedRowsType', } type_list = [] @@ -818,6 +827,12 @@ def OpGenerator( ops_declare_list = [] # all op class declare store in this list ops_defined_list = [] # all op class defined store in this list ops_vjp_defined_list = [] # all op vjp static interface defination + + # (4) parse name of ops which have custom vjp rules + custom_vjp_op_name_list = [] + for custom_vjp in vjp_gen.CUSTOM_VJP: + custom_vjp_op_name_list.append(custom_vjp[:-5]) # cut _grad + for key, op_info in op_info_items.items(): # get op inputs info op_input_name_list = op_info.input_name_list @@ -873,6 +888,10 @@ def OpGenerator( op_interfaces += ["paddle::dialect::VjpInterface"] exclusive_interface_str = gen_exclusive_interface_str(op_info) + # if op has custom vjp rule, then append a CustomVjpTrait to it + if op_info.op_phi_name[0] in custom_vjp_op_name_list: + op_traits += ["paddle::dialect::CustomVjpTrait"] + # check op inputs and mutable_attributes grad semantics input_grad_semantics = get_input_grad_semantic(op_info, op_info_items) mutable_attribute_grad_semantics = get_mutable_attribute_grad_semantic( @@ -881,7 +900,7 @@ def OpGenerator( # If op has inplace info, we will generate inplace op and non-inplace op. for op_name in op_info.op_phi_name: - if op_name in _NO_NEED_GEN_OPS: + if op_name in PD_MANUAL_OP_LIST: continue op_class_name = to_pascal_case(op_name) + "Op" op_dialect_name = dialect_name + "." 
+ op_name @@ -927,6 +946,7 @@ def OpGenerator( op_class_name, op_input_name_list, op_input_type_list, + op_input_optional_list, op_attribute_name_list, op_attribute_type_list, op_attribute_build_arg_type_list, @@ -940,7 +960,9 @@ def OpGenerator( op_output_name_list, op_output_type_list, op_output_size_list, + op_output_optional_list, op_infer_meta_map, + op_inplace_map, muta_attr_is_input=False, ) if len(op_attribute_name_list) > 0: @@ -951,6 +973,7 @@ def OpGenerator( op_class_name, op_input_name_list, op_input_type_list, + op_input_optional_list, op_attribute_name_list, op_attribute_type_list, op_attribute_build_arg_type_list, @@ -964,7 +987,9 @@ def OpGenerator( op_output_name_list, op_output_type_list, op_output_size_list, + op_output_optional_list, op_infer_meta_map, + op_inplace_map, muta_attr_is_input=False, attr_args_is_map=True, ) @@ -982,6 +1007,7 @@ def OpGenerator( op_class_name, op_input_name_list, op_input_type_list, + op_input_optional_list, op_attribute_name_list, op_attribute_type_list, op_attribute_build_arg_type_list, @@ -995,7 +1021,9 @@ def OpGenerator( op_output_name_list, op_output_type_list, op_output_size_list, + op_output_optional_list, op_infer_meta_map, + op_inplace_map, muta_attr_is_input=True, ) @@ -1188,7 +1216,6 @@ def OpGenerator( if dialect_name == "cinn": logging.warning("cinn is currently not support Vjp function") else: - # TODO(chenzhiyang) add vjp gen code if ( op_info.backward_name and op_info.op_phi_name[0] diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py similarity index 55% rename from paddle/fluid/ir/dialect/op_generator/op_interface_gen.py rename to paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 2490335f6c3fb..db763146fb1d3 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -26,14 +26,20 @@ {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """ - ir::CombineOp combine_op_obj = - op_obj.{input_name}().GetDefiningOp()->dyn_cast(); + pir::CombineOp combine_op_obj = + op_obj.{input_name}().GetDefiningOp()->dyn_cast(); std::vector {input_name}; for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{ {input_name}.emplace_back( std::make_shared(combine_op_obj.inputs()[idx])); }}""" +OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE = """ + paddle::optional {input_name}; + if (op_obj.{input_name}().type().storage()){{ + {input_name} = paddle::make_optional(Tensor(std::make_shared(op_obj.{input_name}()))); + }}""" + OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """ Tensor {output_grad_name}(std::make_shared(out_grads[{idx1}][{idx2}]));""" @@ -50,6 +56,11 @@ OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """ {attr_type} {attr_name} = {default_value};""" +OP_VJP_ATTRIBUTE_ARRAY_TEMPLATE = """ + {attr_type} {attr_name}; + for (size_t i = 0; i < op->attribute("{attr_name}").dyn_cast().size(); i++) {{ + {attr_name}.push_back(op->attribute("{attr_name}").dyn_cast().at(i).dyn_cast<{inner_type}>().{func}()); + }}""" OP_VJP_CALL_VJP_TEMPLATE = """ std::vector> tensor_res = @@ -57,19 +68,19 @@ {inputs_list}stop_gradients);""" OP_VJP_STOPGRADIENT_TEMPLATE = """ - std::vector> res(tensor_res.size()); + std::vector> res(tensor_res.size()); for (size_t i = 0; i < tensor_res.size(); ++i) { res[i].resize(tensor_res[i].size()); for (size_t j = 0; j < tensor_res[i].size(); ++j) { if(tensor_res[i][j].defined()){ - res[i][j] = 
std::static_pointer_cast(tensor_res[i][j].impl())->getValue().dyn_cast(); + res[i][j] = std::static_pointer_cast(tensor_res[i][j].impl())->value().dyn_cast(); } } }""" OP_VJP_DEFINE_TEMPLATE = """ -std::vector> {op_class_name}::Vjp(ir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients){{ - {op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); +std::vector> {op_class_name}::Vjp(pir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients){{ + {op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); (void)op_obj; VLOG(6) << "Prepare inputs of {op_grad_name}"; {forward_input_output_code} @@ -89,11 +100,7 @@ input_types_map = { 'paddle::dialect::DenseTensorType': 'Tensor', - 'ir::VectorType': 'Tensor[]', -} - -attr_data_map = { - 'ir::StrAttribute': 'AsString', + 'pir::VectorType': 'Tensor[]', } @@ -111,45 +118,53 @@ def gen_op_vjp_str( grad_idx = -1 for idx in range(len(bw_input_list)): build_args_str += bw_input_list[idx] + ", " - if ( - bw_input_list[idx] in op_info.input_name_list - or bw_input_list[idx] in op_info.output_name_list - ): - input_type = input_types_map[op_grad_info.input_type_list[idx]] - if input_type == 'Tensor': - forward_input_output_code += ( - OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( - input_type=input_type, - input_name=bw_input_list[idx], - ) - ) - else: - forward_input_output_code += ( - OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], - ) + if op_grad_info.input_optional_list[idx] == 'true': + forward_input_output_code += ( + OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE.format( + input_name=bw_input_list[idx], ) + ) else: - grad_idx += 1 - input_type = input_types_map[op_grad_info.input_type_list[idx]] - if input_type == 'Tensor': - forward_output_grad_code += ( - OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format( - output_grad_name=bw_input_list[idx], - idx1=grad_idx, - idx2=0, + if ( + bw_input_list[idx] in op_info.input_name_list + or bw_input_list[idx] in op_info.output_name_list + ): + input_type = input_types_map[op_grad_info.input_type_list[idx]] + if input_type == 'Tensor': + forward_input_output_code += ( + OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( + input_type=input_type, + input_name=bw_input_list[idx], + ) + ) + else: + forward_input_output_code += ( + OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE.format( + input_name=bw_input_list[idx], + ) ) - ) else: - forward_input_output_code += ( - OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE.format( - output_grad_name=bw_input_list[idx], index=grad_idx + grad_idx += 1 + input_type = input_types_map[op_grad_info.input_type_list[idx]] + if input_type == 'Tensor': + forward_output_grad_code += ( + OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format( + output_grad_name=bw_input_list[idx], + idx1=grad_idx, + idx2=0, + ) + ) + else: + forward_input_output_code += ( + OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE.format( + output_grad_name=bw_input_list[idx], index=grad_idx + ) ) - ) op_attribute_list = op_grad_info.attribute_name_list attribute_code = '' + build_attr_str = '' + array_attr_str = "pir::ArrayAttribute" for idx in range(len(op_attribute_list)): - build_args_str += op_attribute_list[idx] + ", " if op_attribute_list[idx] in op_info.attribute_name_list: if op_attribute_list[idx] in op_info.mutable_attribute_name_list: attribute_code += ( @@ -158,19 +173,38 @@ def gen_op_vjp_str( input_name=op_attribute_list[idx], ) ) + build_args_str += op_attribute_list[idx] + ", " else: - func = 'data' - if ( - op_grad_info.attribute_type_list[idx] - in 
attr_data_map.keys() - ): - func = attr_data_map[op_grad_info.attribute_type_list[idx]] - attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format( - attr_type=op_grad_info.attribute_gen_arg_type_list[idx], - attr_name=op_attribute_list[idx], - attr_parse_type=op_grad_info.attribute_type_list[idx], - func=func, - ) + func = "data" + attr_type = op_grad_info.attribute_gen_arg_type_list[idx] + attr_type = attr_type.replace("const ", "") + attr_type = attr_type.replace("&", "") + if array_attr_str in op_grad_info.attribute_type_list[idx]: + inner_type = op_grad_info.attribute_type_list[idx][ + len(array_attr_str) + 1 : -1 + ] + func = "data" + if inner_type == "pir::StrAttribute": + func = "AsString" + attribute_code += OP_VJP_ATTRIBUTE_ARRAY_TEMPLATE.format( + attr_type=attr_type, + attr_name=op_attribute_list[idx], + inner_type=inner_type, + func=func, + ) + else: + if ( + op_grad_info.attribute_type_list[idx] + == "pir::StrAttribute" + ): + func = "AsString" + attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format( + attr_type=attr_type, + attr_name=op_attribute_list[idx], + attr_parse_type=op_grad_info.attribute_type_list[idx], + func=func, + ) + build_attr_str += op_attribute_list[idx] + ", " else: attribute_code += OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE.format( @@ -178,6 +212,8 @@ def gen_op_vjp_str( attr_name=op_attribute_list[idx], default_value=op_grad_info.attribute_default_value_list[idx], ) + build_attr_str += op_attribute_list[idx] + ", " + build_args_str += build_attr_str op_phi_name_format = op_phi_name if op_phi_name[-1] == '_': op_phi_name_format = op_phi_name[:-1] @@ -218,5 +254,5 @@ def gen_exclusive_interface_str(op_info): " static void InferMeta( phi::InferMetaContext *infer_meta );" ) if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list: - exclusive_interface_str += "\n static std::vector> Vjp(ir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients);" + exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str diff --git a/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py b/paddle/fluid/pir/dialect/op_generator/op_member_func_gen.py similarity index 88% rename from paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py rename to paddle/fluid/pir/dialect/op_generator/op_member_func_gen.py index 9bc2c75ccf8a9..1cf32a44c5f60 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_member_func_gen.py @@ -14,9 +14,9 @@ # generator op member function -OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand_source({input_index}); }} +OP_GET_INPUT_TEMPLATE = """ pir::Value {input_name}() {{ return operand_source({input_index}); }} """ -OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }} +OP_GET_OUTPUT_TEMPLATE = """ pir::OpResult {output_name}() {{ return result({output_index}); }} """ diff --git a/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py similarity index 91% rename from paddle/fluid/ir/dialect/op_generator/op_verify_gen.py rename to paddle/fluid/pir/dialect/op_generator/op_verify_gen.py index 917728f2c8b17..4dffdb2c7b814 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py @@ -43,7 +43,7 @@ PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), 
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));""" INPUT_VECTORTYPE_CHECK_TEMPLATE = """ - if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast()) {{ + if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); ++i) {{ PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); @@ -60,7 +60,7 @@ }}""" INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ if (auto val = (*this)->operand({index})) {{ - if (auto vec_type = val.type().dyn_cast()) {{ + if (auto vec_type = val.type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); i++) {{ PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); @@ -75,10 +75,10 @@ PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(), phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));""" ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """ - PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa(), + PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa(), phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); - for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ - PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast().at(i).isa<{standard}>(), + for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ + PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast().at(i).isa<{standard}>(), phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); }}""" OUTPUT_TYPE_CHECK_TEMPLATE = """ @@ -86,7 +86,7 @@ phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));""" OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """ auto output_{index}_type = (*this)->result({index}).type(); - if (auto vec_type = output_{index}_type.dyn_cast()) {{ + if (auto vec_type = output_{index}_type.dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); i++) {{ PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); @@ -103,7 +103,7 @@ }}""" OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ if (auto output_{index}_type = (*this)->result({index}).type()) {{ - if (auto vec_type = output_{index}_type.dyn_cast()) {{ + if (auto vec_type = output_{index}_type.dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); ++i) {{ PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); @@ -128,13 +128,14 @@ def gen_inputs_type_check_str( // Inputs num is 0, not need to check inputs type.""" else: inputs_type_check_str = "" + vector_type_str = "pir::VectorType<" for idx in range(len(op_input_type_list)): input_type = op_input_type_list[idx] is_optional = op_input_optional_list[idx] is_vector = False - if input_type.startswith("ir::VectorType<"): + if input_type.startswith(vector_type_str): is_vector = True - input_type = input_type[15:-1] + input_type = input_type[len(vector_type_str) : -1] check_str = "" if is_optional == "true": if is_vector: @@ -182,11 +183,13 @@ def gen_attributes_type_check_str( else: attributes_check_str = """ auto& attributes = this->attributes();""" + array_attr_str = "pir::ArrayAttribute<" for idx 
in range(len(op_non_mutable_attribute_name_list)): attribute_name = op_non_mutable_attribute_name_list[idx] attribute_type = op_non_mutable_attribute_type_list[idx] - if attribute_type.startswith("ir::ArrayAttribute<"): - attribute_type = attribute_type[19:-1] + + if attribute_type.startswith(array_attr_str): + attribute_type = attribute_type[len(array_attr_str) : -1] attributes_check_str += ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format( attribute_name=attribute_name, standard=attribute_type, @@ -205,13 +208,14 @@ def gen_outputs_type_check_str(op_output_type_list, op_output_optional_list): // Outputs num is 0, not need to check outputs type.""" else: outputs_type_check_str = "" + vector_type_str = "pir::VectorType<" for idx in range(len(op_output_type_list)): output_type = op_output_type_list[idx] is_optional = op_output_optional_list[idx] is_vector = False - if output_type.startswith("ir::VectorType<"): + if output_type.startswith(vector_type_str): is_vector = True - output_type = output_type[15:-1] + output_type = output_type[len(vector_type_str) : -1] check_str = "" if is_optional == "true": if is_vector: diff --git a/paddle/fluid/ir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py similarity index 83% rename from paddle/fluid/ir/dialect/op_generator/ops_api_gen.py rename to paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index bde5c7c23a7bc..9f04a9b2fd4b2 100644 --- a/paddle/fluid/ir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -15,13 +15,14 @@ import argparse import os -from api_gen import NAMESPACE_TEMPLATE, PD_MANUAL_OP_LIST, CodeGen +from api_gen import NAMESPACE_TEMPLATE, CodeGen CPP_FILE_TEMPLATE = """ #include #include "paddle/fluid/pybind/static_op_function.h" #include "paddle/fluid/pybind/eager_op_function.h" +#include "paddle/fluid/pybind/manual_static_op_function.h" #include "paddle/phi/core/enforce.h" #include "paddle/fluid/eager/api/utils/global_utils.h" @@ -41,6 +42,9 @@ if (PyModule_AddFunctions(module->ptr(), OpsAPI) < 0) {{ PADDLE_THROW(phi::errors::Fatal("Add C++ api to core.ops failed!")); }} + if (PyModule_AddFunctions(module->ptr(), ManualOpsAPI) < 0) {{ + PADDLE_THROW(phi::errors::Fatal("Add C++ api to core.ops failed!")); + }} }} """ @@ -55,7 +59,7 @@ }} }}""" -NO_DY_FUNCTION_IMPL_TEMPLATE = """ +STATIC_ONLY_FUNCTION_IMPL_TEMPLATE = """ static PyObject *{name}(PyObject *self, PyObject *args, PyObject *kwargs) {{ VLOG(6) << "Call static_api_{name}"; return static_api_{name}(self, args, kwargs); @@ -64,8 +68,9 @@ OPS_API_TEMPLATE = """ {{"{name}", (PyCFunction)(void (*)(void)){name}, METH_VARARGS | METH_KEYWORDS, "C++ interface function for {name}."}},""" -SPECIAL_STATIC_ONLY_APIS = [ - 'fetch', +NEED_GEN_STATIC_ONLY_APIS = ['fetch'] + +NO_NEED_GEN_STATIC_ONLY_APIS = [ 'set_value_with_tensor', 'set_value_with_tensor_', 'fused_bn_add_activation_', @@ -86,6 +91,10 @@ 'c_allreduce_sum', 'c_embedding', 'c_identity', + 'c_reduce_sum', + 'c_allreduce_max', + 'c_allgather', + 'seed', ] @@ -93,14 +102,16 @@ class OpsAPIGen(CodeGen): def __init__(self) -> None: super().__init__() + def _need_skip(self, op_info, op_name): + return ( + super()._need_skip(op_info, op_name) + or op_name.endswith(('_grad', '_grad_', 'xpu')) + or op_name in NO_NEED_GEN_STATIC_ONLY_APIS + ) + def _gen_one_function_impl(self, name): - if ( - name.endswith('_grad') - or name.endswith('_grad_') - or name.endswith('xpu') - or name in SPECIAL_STATIC_ONLY_APIS - ): - return 
NO_DY_FUNCTION_IMPL_TEMPLATE.format(name=name) + if name in NEED_GEN_STATIC_ONLY_APIS: + return STATIC_ONLY_FUNCTION_IMPL_TEMPLATE.format(name=name) else: return FUNCTION_IMPL_TEMPLATE.format(name=name) @@ -117,10 +128,7 @@ def gen_cpp_file( ops_api_str = '' for op_info in op_info_items: for op_name in op_info.op_phi_name: - if ( - op_info.infer_meta_func is None - and op_name not in PD_MANUAL_OP_LIST - ): + if self._need_skip(op_info, op_name): continue function_impl_str += self._gen_one_function_impl(op_name) ops_api_str += self._gen_one_ops_api(op_name) diff --git a/paddle/fluid/ir/dialect/op_generator/python_c_gen.py b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py similarity index 67% rename from paddle/fluid/ir/dialect/op_generator/python_c_gen.py rename to paddle/fluid/pir/dialect/op_generator/python_c_gen.py index a890a8db5d249..62c98bcef9f80 100644 --- a/paddle/fluid/ir/dialect/op_generator/python_c_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py @@ -15,13 +15,7 @@ import argparse import re -from api_gen import ( - NAMESPACE_TEMPLATE, - OP_RESULT, - PD_MANUAL_OP_LIST, - VECTOR_TYPE, - CodeGen, -) +from api_gen import NAMESPACE_TEMPLATE, OP_RESULT, VECTOR_TYPE, CodeGen H_FILE_TEMPLATE = """ @@ -46,7 +40,7 @@ CPP_FILE_TEMPLATE = """ #include "paddle/fluid/pybind/static_op_function.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/op_function_common.h" @@ -125,20 +119,14 @@ {attrs_py_obj} // Check for mutable attrs - bool has_mutable_attr = false; - {check_mutable_attrs} - - if (has_mutable_attr){{ - {cast_attrs_with_mutable} - // Call ir static api - auto static_api_out = paddle::dialect::{api_name}({args_with_mutable_attrs}); - return ToPyObject(static_api_out); - }} else {{ - {cast_attrs_without_mutable} - // Call ir static api - auto static_api_out = paddle::dialect::{api_name}({args_without_mutable_attrs}); - return ToPyObject(static_api_out); - }} + {init_attrs} + {cast_attrs} + + // Call ir static api + auto static_api_out = paddle::dialect::{api_name}({args_with_mutable_attrs}); + return ToPyObject(static_api_out); + + }} catch (...) 
{{ ThrowExceptionToPython(std::current_exception()); return nullptr; @@ -146,18 +134,40 @@ }} """ -CHECK_MUTABLE_ATTR_TEMPLATE = """ +INIT_ATTRS_TEMPLATE = """ + {type} {name}; +""" +MUTABLE_ATTR_TEMPLATE = """ if (PyObject_CheckIROpResult({name}_obj)){{ - has_mutable_attr = true; + {mutable_cast_attrs} + }}else{{ + {no_mutable_cast_attrs} + }}""" + +MUTABLE_ATTR_LIST_TEMPLATE = """ + if (PyObject_CheckIRVectorOfOpResult({name}_obj)){{ + {mutable_cast_attrs} + }}else{{ + {no_mutable_cast_attrs} }}""" MUTABLE_ATTR_OBJ_TEMPLATE = """ PyObject *{name}_obj = PyTuple_GET_ITEM(args, {index});""" MUTABLE_ATTR_CAST_TEMPLATE = """ - {type} {name} = {cast_func}({name}_obj, "{api_name}", {index});""" + {type} {name_} = {cast_func}({name}_obj, "{api_name}", {index});""" + +FULL_OP_TEMPLATE = """ + {name} = paddle::dialect::full(std::vector{{1}}, {name}_tmp, phi::DataType::{phi_datatype}, phi::CPUPlace()); +""" +FULL_INT_ARRAY_OP_TEMPLATE = """ + {name} = paddle::dialect::full_int_array({name}_tmp, phi::DataType::{phi_datatype}, phi::CPUPlace()); +""" +BUILTIN_COMBINE_OP_TEMPLATE = """ + {name} = paddle::dialect::builtin_combine({name}_tmp); +""" TYPE_TO_FUNC_MAP = { "bool": "CastPyArg2Boolean", "int": "CastPyArg2Int", @@ -181,6 +191,21 @@ "phi::DataType": "CastPyArg2DataTypeDirectly", } +TYPE_TO_PHI_DATATYPE_MAP = { + "bool": "BOOL", + "int": "INT32", + "long": "INT64", + "int64_t": "INT64", + "float": "FLOAT32", + "double": "FLOAT64", + "std::vector": "BOOL", + "std::vector": "INT32", + "std::vector": "INT64", + "std::vector": "INT64", + "std::vector": "FLOAT32", + "std::vector": "FLOAT64", +} + class PythonCCodeGen(CodeGen): def __init__(self) -> None: @@ -195,10 +220,7 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path): for op_name in op_info.op_phi_name: # NOTE:When infer_meta_func is None, the Build() function generated in pd_op # is wrong, so temporarily skip the automatic generation of these APIs - if ( - op_info.infer_meta_func is None - and op_name not in PD_MANUAL_OP_LIST - ): + if self._need_skip(op_info, op_name): continue declare_str += self._gen_one_declare(op_name) @@ -252,33 +274,101 @@ def _gen_attrs_py_obj_with_mutable(self, op_info): ) return ret - def _gen_check_mutable_attrs(self, op_info): - name_list = op_info.mutable_attribute_name_list + def _gen_init_mutable_attrs(self, op_info): + mutable_attr_name_list = op_info.mutable_attribute_name_list ret = '' - for name in name_list: - ret += CHECK_MUTABLE_ATTR_TEMPLATE.format(name=name) + for name in mutable_attr_name_list: + ret += INIT_ATTRS_TEMPLATE.format(type=OP_RESULT, name=name) + return ret - def _gen_cast_attrs(self, op_info, op_name, with_mutable): + def _gen_cast_attrs(self, op_info, op_name): input_size = len(op_info.input_name_list) attr_name_list = op_info.attribute_name_list attr_type_list = op_info.attribute_build_arg_type_list mutable_attr_name_list = op_info.mutable_attribute_name_list + mutable_attr_type_list = op_info.mutable_attribute_type_list assert len(attr_name_list) == len(attr_type_list) ret = '' for i, (name, type) in enumerate(zip(attr_name_list, attr_type_list)): type = type.replace('const ', '').replace('&', '') cast_func = TYPE_TO_FUNC_MAP[type] - if with_mutable and name in mutable_attr_name_list: - type = OP_RESULT - cast_func = 'CastPyArg2OpResult' - ret += MUTABLE_ATTR_CAST_TEMPLATE.format( - type=type, - name=name, - cast_func=cast_func, - api_name=op_name, - index=input_size + i, - ) + + if name in mutable_attr_name_list: + phi_dtype = TYPE_TO_PHI_DATATYPE_MAP[type] + if ( + 
mutable_attr_type_list[mutable_attr_name_list.index(name)][ + 0 + ] + == "paddle::dialect::IntArrayAttribute" + ): + mutable_cast_str = MUTABLE_ATTR_CAST_TEMPLATE.format( + type='std::vector', + name_=name + '_tmp', + name=name, + cast_func='CastPyArg2VectorOfOpResult', + api_name=op_name, + index=input_size + i, + ) + mutable_cast_str += BUILTIN_COMBINE_OP_TEMPLATE.format( + name=name + ) + + else: + mutable_cast_str = MUTABLE_ATTR_CAST_TEMPLATE.format( + type='', + name_=name, + name=name, + cast_func='CastPyArg2OpResult', + api_name=op_name, + index=input_size + i, + ) + + no_mutable_cast_str = MUTABLE_ATTR_CAST_TEMPLATE.format( + type=type, + name_=name + '_tmp', + name=name, + cast_func=cast_func, + api_name=op_name, + index=input_size + i, + ) + + if ( + mutable_attr_type_list[mutable_attr_name_list.index(name)][ + 0 + ] + == "paddle::dialect::IntArrayAttribute" + ): + no_mutable_cast_str += FULL_INT_ARRAY_OP_TEMPLATE.format( + name=name, + phi_datatype=phi_dtype, + ) + ret += MUTABLE_ATTR_LIST_TEMPLATE.format( + name=name, + mutable_cast_attrs=mutable_cast_str, + no_mutable_cast_attrs=no_mutable_cast_str, + ) + else: + no_mutable_cast_str += FULL_OP_TEMPLATE.format( + name=name, + phi_datatype=phi_dtype, + ) + ret += MUTABLE_ATTR_TEMPLATE.format( + name=name, + mutable_cast_attrs=mutable_cast_str, + no_mutable_cast_attrs=no_mutable_cast_str, + ) + else: + mutable_cast_str = MUTABLE_ATTR_CAST_TEMPLATE.format( + type=type, + name_=name, + name=name, + cast_func=cast_func, + api_name=op_name, + index=input_size + i, + ) + ret += mutable_cast_str + return ret def _gen_one_impl(self, op_info, op_name): @@ -300,21 +390,13 @@ def _gen_one_impl(self, op_info, op_name): api_name=op_name, inputs=self._gen_inputs(op_info, op_name), attrs_py_obj=self._gen_attrs_py_obj_with_mutable(op_info), - check_mutable_attrs=self._gen_check_mutable_attrs(op_info), - cast_attrs_with_mutable=self._gen_cast_attrs( - op_info, op_name, True - ), + init_attrs=self._gen_init_mutable_attrs(op_info), + cast_attrs=self._gen_cast_attrs(op_info, op_name), args_with_mutable_attrs=', '.join( input_name_list + mutable_attr_name_list + no_mutable_attr_name_list ), - cast_attrs_without_mutable=self._gen_cast_attrs( - op_info, op_name, False - ), - args_without_mutable_attrs=', '.join( - input_name_list + attr_name_list - ), ) else: ret = NO_MUTABLE_ATTR_API_IMPL_TEMPLATE.format( @@ -332,10 +414,7 @@ def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path): for op_name in op_info.op_phi_name: # NOTE:When infer_meta_func is None, the Build() function generated in pd_op # is wrong, so temporarily skip the automatic generation of these APIs - if ( - op_info.infer_meta_func is None - and op_name not in PD_MANUAL_OP_LIST - ): + if self._need_skip(op_info, op_name): continue impl_str += self._gen_one_impl(op_info, op_name) body = impl_str diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py similarity index 74% rename from paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py rename to paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py index 9707d6fb5f9a2..2bbce72200d0c 100644 --- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -22,6 +22,7 @@ # remove this file and support Vjp methods # code gen. 
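# How the two lists below are consumed (see op_gen.py and
# gen_exclusive_interface_str in op_interface_gen.py): an op named in
# vjp_interface_declare_gen_op_list gets a static Vjp(...) declaration on its
# generated op class, while an op that is also named in
# vjp_interface_implementation_gen_op_list additionally gets its Vjp body
# auto-generated into pd_op_vjp.cc. An op whose Vjp body is written by hand
# should therefore appear only in the first list. The relevant check is the
# one already in gen_exclusive_interface_str:
#
#   if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list:
#       ...  # emit the Vjp declaration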
+
 vjp_interface_declare_gen_op_list = [
     "tanh",
     "mean",
@@ -34,15 +35,29 @@
     "matmul",
     "erf",
     "multiply",
-    "subtract",
     "pow",
     "rsqrt",
+    "subtract",
+    "square",
     "dropout",
+    'exp',
+    'expand',
+    'layer_norm',
+    'reshape',
+    'cast',
+    'softmax',
+    'silu',
+    'elementwise_pow',
+    'fused_softmax_mask_upper_triangle',
+    'slice',
+    'transpose',
+    'slice_double',
 ]
 vjp_interface_implementation_gen_op_list = [
     "tanh",
     "mean",
     "divide",
+    "sum",
     "add",
     "concat",
     "split",
@@ -53,5 +68,18 @@
     "subtract",
     "pow",
     "rsqrt",
+    "square",
     "dropout",
+    'exp',
+    'expand',
+    'layer_norm',
+    'reshape',
+    'cast',
+    'softmax',
+    'silu',
+    'elementwise_pow',
+    'fused_softmax_mask_upper_triangle',
+    'slice',
+    'transpose',
+    'slice_double',
 ]
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/CMakeLists.txt
similarity index 100%
rename from paddle/fluid/ir/dialect/paddle_dialect/CMakeLists.txt
rename to paddle/fluid/pir/dialect/operator/CMakeLists.txt
diff --git a/paddle/fluid/pir/dialect/operator/interface/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/interface/CMakeLists.txt
new file mode 100644
index 0000000000000..a6496585e7790
--- /dev/null
+++ b/paddle/fluid/pir/dialect/operator/interface/CMakeLists.txt
@@ -0,0 +1,7 @@
+# All source files of pd_op_dialect, except for the op source files, which are generated in the compilation directory.
+file(GLOB PD_INTERFACE_SRCS "*.cc")
+
+cc_library(
+  pd_interface
+  SRCS ${PD_INTERFACE_SRCS}
+  DEPS pir_core phi_utils)
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h b/paddle/fluid/pir/dialect/operator/interface/infermeta.h
similarity index 77%
rename from paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h
rename to paddle/fluid/pir/dialect/operator/interface/infermeta.h
index ba3d54c59439b..958d2df369ed9 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h
+++ b/paddle/fluid/pir/dialect/operator/interface/infermeta.h
@@ -13,13 +13,14 @@
 // limitations under the License.

 #pragma once
-#include "paddle/ir/core/op_base.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/pir/core/op_base.h"

 namespace paddle {
 namespace dialect {
-class InferMetaInterface : public ir::OpInterfaceBase<InferMetaInterface> {
+class InferMetaInterface : public pir::OpInterfaceBase<InferMetaInterface> {
  public:
+  /// Defined these methods with the interface.
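The Concept/Model pair that follows is a hand-rolled type-erasure idiom: Concept stores a plain function pointer, and Model<ConcreteOp> binds it to the concrete op's static InferMeta at registration time, so querying the interface needs no virtual dispatch. A minimal self-contained analogue (names here are illustrative, not part of this patch):

  #include <iostream>

  // Concept erases the concrete op type down to a function pointer.
  struct Concept {
    explicit Concept(void (*run)(int)) : run_(run) {}
    void (*run_)(int);
  };

  // Model<ConcreteOp> binds the pointer to ConcreteOp's static method.
  template <typename ConcreteOp>
  struct Model : Concept {
    Model() : Concept(&ConcreteOp::Run) {}
  };

  struct MyOp {
    static void Run(int x) { std::cout << "MyOp::Run(" << x << ")\n"; }
  };

  int main() {
    Model<MyOp> model;
    Concept *interface = &model;
    interface->run_(42);  // dispatches to MyOp::Run without a vtable
    return 0;
  }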
  struct Concept {
    explicit Concept(void (*infer_meta)(phi::InferMetaContext *))
        : infer_meta_(infer_meta) {}
@@ -28,15 +29,16 @@ class InferMetaInterface : public ir::OpInterfaceBase<InferMetaInterface> {

  template <typename ConcreteOp>
  struct Model : public Concept {
-    static void InferMeta(phi::InferMetaContext *infer_meta) {
+    static inline void InferMeta(phi::InferMetaContext *infer_meta) {
      return ConcreteOp::InferMeta(infer_meta);
    }
    Model() : Concept(InferMeta) {}
  };

-  InferMetaInterface(ir::Operation *op, Concept *impl)
-      : ir::OpInterfaceBase<InferMetaInterface>(op), impl_(impl) {}
+  /// Constructor
+  InferMetaInterface(pir::Operation *op, Concept *impl)
+      : pir::OpInterfaceBase<InferMetaInterface>(op), impl_(impl) {}

  void InferMeta(phi::InferMetaContext *infer_meta) {
    impl_->infer_meta_(infer_meta);
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/interface/interface.cc b/paddle/fluid/pir/dialect/operator/interface/interface.cc
similarity index 79%
rename from paddle/fluid/ir/dialect/paddle_dialect/interface/interface.cc
rename to paddle/fluid/pir/dialect/operator/interface/interface.cc
index 12b14de308640..92b3bf0ba2168 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/interface/interface.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/interface.cc
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/interface/vjp.h"
+#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
+#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
+#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"

 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferMetaInterface)
 IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OpYamlInfoInterface)
diff --git a/paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h b/paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h
similarity index 83%
rename from paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h
rename to paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h
index 7663fb2029a43..33011f5613eb5 100644
--- a/paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h
+++ b/paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h
@@ -14,8 +14,8 @@

 #pragma once

-#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h"
-#include "paddle/ir/core/op_base.h"
+#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h"
+#include "paddle/pir/core/op_base.h"

 using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
                                std::vector<paddle::dialect::OpAttributeInfo>,
@@ -25,7 +25,7 @@ using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
 namespace paddle {
 namespace dialect {

-class OpYamlInfoInterface : public ir::OpInterfaceBase<OpYamlInfoInterface> {
+class OpYamlInfoInterface : public pir::OpInterfaceBase<OpYamlInfoInterface> {
  public:
  struct Concept {
    explicit Concept(OpInfoTuple (*get_op_info)())
@@ -40,8 +40,8 @@ class OpYamlInfoInterface : public ir::OpInterfaceBase<OpYamlInfoInterface> {
    Model() : Concept(GetOpInfo) {}
  };

-  OpYamlInfoInterface(ir::Operation *op, Concept *impl)
-      : ir::OpInterfaceBase<OpYamlInfoInterface>(op), impl_(impl) {}
+  OpYamlInfoInterface(pir::Operation *op, Concept *impl)
+      : pir::OpInterfaceBase<OpYamlInfoInterface>(op), impl_(impl) {}

  OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); }

diff --git a/paddle/fluid/ir/dialect/paddle_dialect/interface/vjp.h b/paddle/fluid/pir/dialect/operator/interface/vjp.h
similarity index 62%
rename from paddle/fluid/ir/dialect/paddle_dialect/interface/vjp.h
rename to
paddle/fluid/pir/dialect/operator/interface/vjp.h index a373cd0bacca4..56c814db89088 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/interface/vjp.h +++ b/paddle/fluid/pir/dialect/operator/interface/vjp.h @@ -13,29 +13,29 @@ // limitations under the License. #pragma once -#include "paddle/ir/core/op_base.h" +#include "paddle/pir/core/op_base.h" namespace paddle { namespace dialect { -class VjpInterface : public ir::OpInterfaceBase { +class VjpInterface : public pir::OpInterfaceBase { public: struct Concept { - explicit Concept(std::vector> (*vjp)( - ir::Operation* op, - const std::vector>& out_grads, + explicit Concept(std::vector> (*vjp)( + pir::Operation* op, + const std::vector>& out_grads, const std::vector>& stop_gradients)) : vjp_(vjp) {} - std::vector> (*vjp_)( - ir::Operation* op, - const std::vector>& out_grads, + std::vector> (*vjp_)( + pir::Operation* op, + const std::vector>& out_grads, const std::vector>& stop_gradients); }; template struct Model : public Concept { - static std::vector> Vjp( - ir::Operation* op, - const std::vector>& out_grads, + static std::vector> Vjp( + pir::Operation* op, + const std::vector>& out_grads, const std::vector>& stop_gradients) { return ConcreteOp::Vjp(op, out_grads, stop_gradients); } @@ -43,12 +43,12 @@ class VjpInterface : public ir::OpInterfaceBase { Model() : Concept(Vjp) {} }; - VjpInterface(ir::Operation* op, Concept* impl) - : ir::OpInterfaceBase(op), impl_(impl) {} + VjpInterface(pir::Operation* op, Concept* impl) + : pir::OpInterfaceBase(op), impl_(impl) {} - std::vector> Vjp( - ir::Operation* op, - const std::vector>& out_grads, + std::vector> Vjp( + pir::Operation* op, + const std::vector>& out_grads, const std::vector>& stop_gradients) { return impl_->vjp_(op, out_grads, stop_gradients); } diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/.gitignore b/paddle/fluid/pir/dialect/operator/ir/.gitignore similarity index 100% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/.gitignore rename to paddle/fluid/pir/dialect/operator/ir/.gitignore diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt similarity index 83% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt rename to paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt index 64e3e982133be..71df1b6811bf7 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt @@ -1,12 +1,12 @@ set(PD_DIALECT_BINARY_DIR - "${PADDLE_BINARY_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir") + "${PADDLE_BINARY_DIR}/paddle/fluid/pir/dialect/operator/ir") -# Generate pd_dialect files defining op using op_gen_file +# Generate pd_op_dialect files defining op using op_gen_file set(op_gen_parsed_yaml_file ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parse_op.py) set(op_gen_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/op_gen.py) + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_gen.py) set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) set(op_forward_yaml_file1 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml @@ -28,23 +28,22 @@ set(fused_op_backward_yaml_file ) set(pd_op_forward_yaml_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops.yaml) + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops.yaml) set(pd_op_backward_yaml_file - 
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops_backward.yaml -) + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml) set(parsed_op_dir - ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/generated) + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated) -set(op_yaml_file3 ${parsed_op_dir}/pd_ops.parsed.yaml) -set(op_yaml_file4 ${parsed_op_dir}/pd_ops_backward.parsed.yaml) +set(op_yaml_file3 ${parsed_op_dir}/ops.parsed.yaml) +set(op_yaml_file4 ${parsed_op_dir}/ops_backward.parsed.yaml) set(op_yaml_files ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3},${op_yaml_file4} ) set(op_namespace paddle,dialect) -set(dialect_name pd) +set(dialect_name pd_op) set(op_header_file ${PD_DIALECT_BINARY_DIR}/pd_op.h) set(op_source_file ${PD_DIALECT_BINARY_DIR}/pd_op.cc) set(op_header_file_tmp ${op_header_file}.tmp) @@ -96,7 +95,7 @@ set(api_gen_yaml_files ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3},${op_yaml_file4} ) set(api_gen_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/api_gen.py) + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/api_gen.py) set(api_header_file ${PD_DIALECT_BINARY_DIR}/pd_api.h) set(api_source_file ${PD_DIALECT_BINARY_DIR}/pd_api.cc) set(api_header_file_tmp ${api_header_file}.tmp) @@ -125,7 +124,7 @@ add_custom_command( VERBATIM) set(python_c_gen_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/python_c_gen.py) + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/python_c_gen.py) set(python_c_header_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/static_op_function.h) set(python_c_source_file @@ -160,7 +159,7 @@ add_custom_target(static_op_function_gen ALL DEPENDS ${python_c_header_file} ${python_c_source_file}) set(ops_api_gen_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/ops_api_gen.py) + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py) set(ops_api_source_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/ops_api.cc) set(ops_api_source_file_tmp ${ops_api_source_file}.tmp) @@ -186,26 +185,26 @@ add_custom_command( add_custom_target(ops_api_gen ALL DEPENDS ${ops_api_source_file}) cc_library( - pd_dialect_core - SRCS pd_attribute.cc pd_type.cc pd_meta_tensor.cc + pd_op_dialect_core + SRCS op_attribute.cc op_type.cc meta_tensor.cc DEPS phi pd_interface pd_trait type_info) cc_library( - pd_dialect_op - SRCS ${op_source_file} pd_manual_op.cc - DEPS pd_dialect_core) + pd_op_dialect_op + SRCS ${op_source_file} manual_op.cc + DEPS pd_op_dialect_core) cc_library( api_builder SRCS api_builder.cc - DEPS ir_core) + DEPS pir_core) cc_library( - pd_dialect_api - SRCS ${api_source_file} pd_manual_api.cc - DEPS api_builder pd_dialect_op) + pd_op_dialect_api + SRCS ${api_source_file} manual_api.cc + DEPS api_builder pd_op_dialect_op) -target_include_directories(pd_dialect_api PRIVATE ${PD_DIALECT_BINARY_DIR}) +target_include_directories(pd_op_dialect_api PRIVATE ${PD_DIALECT_BINARY_DIR}) cc_library( - pd_dialect - SRCS pd_dialect.cc pd_manual_op_vjp.cc ${op_vjp_source_file} - DEPS pd_dialect_api param_to_variable primitive_vjp_experimental - pd_dialect_utils op_yaml_info_parser) + pd_op_dialect + SRCS op_dialect.cc manual_op_vjp.cc ${op_vjp_source_file} + DEPS pd_op_dialect_api param_to_variable primitive_vjp_experimental + 
pd_op_dialect_utils op_yaml_info_parser) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.cc b/paddle/fluid/pir/dialect/operator/ir/api_builder.cc similarity index 80% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.cc rename to paddle/fluid/pir/dialect/operator/ir/api_builder.cc index 0ded4ee1a5de8..893c664b78b08 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.cc +++ b/paddle/fluid/pir/dialect/operator/ir/api_builder.cc @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h" -#include "paddle/ir/core/enforce.h" -#include "paddle/ir/core/ir_context.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/ir_context.h" namespace paddle { namespace dialect { APIBuilder::APIBuilder() : builder_(nullptr) { - ctx_ = ir::IrContext::Instance(); + ctx_ = pir::IrContext::Instance(); } -void APIBuilder::SetProgram(ir::Program* program) { - builder_ = std::make_shared(ctx_, program->block()); +void APIBuilder::SetProgram(pir::Program* program) { + builder_ = std::make_shared(ctx_, program->block()); } -void APIBuilder::SetInsertionPoint(ir::Operation* op) { +void APIBuilder::SetInsertionPoint(pir::Operation* op) { IR_ENFORCE(builder_ != nullptr, "builder doesn't hold program, please call SetProgram for " "initialization."); diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h b/paddle/fluid/pir/dialect/operator/ir/api_builder.h similarity index 78% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h rename to paddle/fluid/pir/dialect/operator/ir/api_builder.h index 029c79c2110c0..a06f529d2c5be 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h +++ b/paddle/fluid/pir/dialect/operator/ir/api_builder.h @@ -15,9 +15,9 @@ #pragma once #include -#include "paddle/ir/core/builder.h" -#include "paddle/ir/core/macros.h" -#include "paddle/ir/core/program.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/macros.h" +#include "paddle/pir/core/program.h" namespace paddle { namespace dialect { @@ -30,25 +30,25 @@ class APIBuilder { static APIBuilder api_builder; return api_builder; } - void SetProgram(ir::Program* program); + void SetProgram(pir::Program* program); /// Set the insertion point to the specified operation, which will cause /// subsequent insertions to go right before it. 
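  /// For illustration (not part of this header): a typical call sequence,
  /// assuming an existing pir::Program `program` and an operation `some_op`:
  ///   APIBuilder::Instance().SetProgram(&program);
  ///   APIBuilder::Instance().SetInsertionPoint(some_op);
  ///   // subsequent dialect API calls now insert right before `some_op`;
  ///   // ResetInsertionPointToStart/End move the point to the block ends.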
- void SetInsertionPoint(ir::Operation* op); + void SetInsertionPoint(pir::Operation* op); void ResetInsertionPointToStart(); void ResetInsertionPointToEnd(); - std::shared_ptr GetBuilder() { return builder_; } + std::shared_ptr GetBuilder() { return builder_; } private: APIBuilder(); DISABLE_COPY_AND_ASSIGN(APIBuilder); - ir::IrContext* ctx_; - std::shared_ptr builder_; + pir::IrContext* ctx_; + std::shared_ptr builder_; }; } // namespace dialect diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute_storage.h b/paddle/fluid/pir/dialect/operator/ir/attribute_storage.h similarity index 84% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute_storage.h rename to paddle/fluid/pir/dialect/operator/ir/attribute_storage.h index 1877e5043fc65..68f066b009329 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute_storage.h +++ b/paddle/fluid/pir/dialect/operator/ir/attribute_storage.h @@ -14,17 +14,17 @@ #pragma once -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/attribute_base.h" -#include "paddle/ir/core/utils.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/common/place.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/attribute_base.h" +#include "paddle/pir/core/utils.h" namespace paddle { namespace dialect { -struct IntArrayAttributeStorage : public ir::AttributeStorage { +struct IntArrayAttributeStorage : public pir::AttributeStorage { using ParamKey = phi::IntArray; explicit IntArrayAttributeStorage(const ParamKey &key) { data_ = key; } @@ -36,9 +36,9 @@ struct IntArrayAttributeStorage : public ir::AttributeStorage { static std::size_t HashValue(const ParamKey &key) { size_t hash_value = 0; hash_value = - ir::hash_combine(hash_value, std::hash()(key.FromTensor())); + pir::hash_combine(hash_value, std::hash()(key.FromTensor())); for (auto value : key.GetData()) { - hash_value = ir::hash_combine(hash_value, std::hash()(value)); + hash_value = pir::hash_combine(hash_value, std::hash()(value)); } return hash_value; } @@ -54,7 +54,7 @@ struct IntArrayAttributeStorage : public ir::AttributeStorage { phi::IntArray data_; }; -struct DataTypeAttributeStorage : public ir::AttributeStorage { +struct DataTypeAttributeStorage : public pir::AttributeStorage { using ParamKey = phi::DataType; explicit DataTypeAttributeStorage(const ParamKey &key) { data_ = key; } @@ -75,7 +75,7 @@ struct DataTypeAttributeStorage : public ir::AttributeStorage { phi::DataType data_; }; -struct PlaceAttributeStorage : public ir::AttributeStorage { +struct PlaceAttributeStorage : public pir::AttributeStorage { using ParamKey = phi::Place; explicit PlaceAttributeStorage(const ParamKey &key) { data_ = key; } @@ -94,7 +94,7 @@ struct PlaceAttributeStorage : public ir::AttributeStorage { phi::Place data_; }; -struct DataLayoutAttributeStorage : public ir::AttributeStorage { +struct DataLayoutAttributeStorage : public pir::AttributeStorage { using ParamKey = phi::DataLayout; explicit DataLayoutAttributeStorage(const ParamKey &key) { data_ = key; } diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc new file mode 100644 index 0000000000000..05bd226bacba4 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/ir/manual_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/core/builtin_op.h" + +namespace paddle { +namespace dialect { + +pir::OpResult builtin_combine(std::vector x) { + auto combine_op = + APIBuilder::Instance().GetBuilder()->Build(x); + return combine_op.out(); +} + +pir::OpResult get_parameter(const std::string& name, + phi::DataType dtype, + const std::vector& shape) { + phi::LoD lod; + size_t offset{0}; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + TransToIrDataType(dtype), + phi::DDim(shape.data(), shape.size()), + phi::DataLayout::UNDEFINED, + lod, + offset); + pir::GetParameterOp get_parameter_op = + APIBuilder::Instance().GetBuilder()->Build( + name, out_dense_tensor_type); + return get_parameter_op.result(0); +} + +void set_parameter(pir::OpResult parameter, const std::string& name) { + APIBuilder::Instance().GetBuilder()->Build(parameter, + name); +} + +pir::OpResult embedding_grad(pir::OpResult x, + pir::OpResult weight, + pir::OpResult out_grad, + int64_t padding_idx, + bool sparse) { + if (weight.type().isa()) { + if (sparse) { + return paddle::dialect::embedding_grad_sparse( + x, weight, out_grad, padding_idx, sparse); + } else { + return paddle::dialect::embedding_grad_dense( + x, weight, out_grad, padding_idx, sparse); + } + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Now we do not support sparse weight embedding_grad.")); + } +} + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h new file mode 100644 index 0000000000000..8c737a52b3aa7 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h @@ -0,0 +1,40 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
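As a usage sketch for the manual APIs defined in manual_api.cc above — hand-written, assuming the APIBuilder singleton from api_builder.h and an OperatorDialect header at the path used elsewhere in this patch; not itself part of the patch:

  #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
  #include "paddle/fluid/pir/dialect/operator/ir/manual_api.h"
  #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
  #include "paddle/pir/core/program.h"

  void BuildParameterCopy() {
    pir::IrContext *ctx = pir::IrContext::Instance();
    ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
    pir::Program program(ctx);
    paddle::dialect::APIBuilder::Instance().SetProgram(&program);
    // Read a parameter as a dense-tensor value, then store it under a new name.
    pir::OpResult w = paddle::dialect::get_parameter(
        "fc_w", phi::DataType::FLOAT32, {64, 64});
    paddle::dialect::set_parameter(w, "fc_w_copy");
  }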
+ +#pragma once + +#include + +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" +#include "paddle/pir/core/op_result.h" + +namespace paddle { +namespace dialect { +pir::OpResult builtin_combine(std::vector x); + +pir::OpResult get_parameter(const std::string& name, + phi::DataType dtype, + const std::vector& shape); + +void set_parameter(pir::OpResult parameter, const std::string& name); + +pir::OpResult embedding_grad(pir::OpResult x, + pir::OpResult weight, + pir::OpResult out_grad, + int64_t padding_idx = -1, + bool sparse = false); + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc similarity index 82% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc rename to paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 058a08a384d2d..3ee3bec97cd89 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -12,20 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/builtin_type.h" -#include "paddle/ir/core/ir_context.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/fusion.h" #include "paddle/phi/infermeta/multiary.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/ir_context.h" namespace paddle { namespace dialect { @@ -33,7 +34,7 @@ namespace dialect { OpInfoTuple AddNOp::GetOpInfo() { std::vector inputs = { OpInputInfo("inputs", - "ir::VectorType", + "pir::VectorType", false, false, false, @@ -57,7 +58,8 @@ void AddNOp::Verify() { 1u, phi::errors::PreconditionNotMet( "The size %d of inputs must be equal to 1.", input_size)); - if (auto vec_type = (*this)->operand(0).type().dyn_cast()) { + if (auto vec_type = + (*this)->operand(0).type().dyn_cast()) { for (size_t i = 0; i < vec_type.size(); ++i) { PADDLE_ENFORCE(vec_type[i].isa() || vec_type[i].isa(), @@ -96,17 +98,17 @@ void AddNOp::Verify() { VLOG(4) << "End Verifying for: AddNOp."; } -void AddNOp::Build(ir::Builder &builder, // NOLINT - ir::OperationArgument &argument, // NOLINT - ir::OpResult inputs) { +void AddNOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::OpResult inputs) { VLOG(4) << "Builder construction inputs"; - std::vector argument_inputs = {inputs}; + std::vector argument_inputs = {inputs}; argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; VLOG(4) << "Builder construction outputs"; - ir::VectorType x = 
inputs.type().dyn_cast(); + pir::VectorType x = inputs.type().dyn_cast(); (void)x; std::vector vec_dense_x; @@ -137,9 +139,9 @@ void AddNOp::Build(ir::Builder &builder, // NOLINT phi::AddNInferMeta(meta_x, &meta_out); - std::vector argument_outputs; - ir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( - ir::IrContext::Instance(), + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), TransToIrDataType(dense_out.dtype()), dense_out.dims(), dense_out.layout(), @@ -158,7 +160,7 @@ OpInfoTuple AddN_Op::GetOpInfo() { std::vector inputs = { paddle::dialect::OpInputInfo( "inputs", - "ir::VectorType", + "pir::VectorType", false, false, false, @@ -172,17 +174,17 @@ OpInfoTuple AddN_Op::GetOpInfo() { return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n_"); } -void AddN_Op::Build(ir::Builder &builder, - ir::OperationArgument &argument, - ir::OpResult inputs_) { +void AddN_Op::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::OpResult inputs_) { VLOG(4) << "Builder construction inputs"; - std::vector argument_inputs = {inputs_}; + std::vector argument_inputs = {inputs_}; argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; VLOG(4) << "Builder construction outputs"; - ir::VectorType inputs = inputs_.type().dyn_cast(); + pir::VectorType inputs = inputs_.type().dyn_cast(); std::vector vec_dense_inputs; for (size_t i = 0; i < static_cast(inputs.size()); i++) { vec_dense_inputs.push_back(phi::DenseTensor( @@ -213,9 +215,9 @@ void AddN_Op::Build(ir::Builder &builder, phi::AddNInferMeta(meta_inputs, &meta_out); - std::vector argument_outputs; - ir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( - ir::IrContext::Instance(), + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_out.dtype()), dense_out.dims(), dense_out.layout(), @@ -236,7 +238,7 @@ void AddN_Op::Verify() { phi::errors::PreconditionNotMet( "The size %d of inputs must be equal to 1.", input_size)); if (auto vec_type = - (*this)->operand_source(0).type().dyn_cast()) { + (*this)->operand_source(0).type().dyn_cast()) { for (size_t i = 0; i < vec_type.size(); ++i) { PADDLE_ENFORCE(vec_type[i].isa() || vec_type[i].isa(), @@ -285,7 +287,7 @@ OpInfoTuple AddNWithKernelOp::GetOpInfo() { std::vector inputs = { paddle::dialect::OpInputInfo( "inputs", - "ir::VectorType", + "pir::VectorType", false, false, false, @@ -300,17 +302,17 @@ OpInfoTuple AddNWithKernelOp::GetOpInfo() { inputs, attributes, outputs, run_time_info, "add_n_with_kernel"); } -void AddNWithKernelOp::Build(ir::Builder &builder, - ir::OperationArgument &argument, - ir::OpResult inputs_) { +void AddNWithKernelOp::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::OpResult inputs_) { VLOG(4) << "Builder construction inputs"; - std::vector argument_inputs = {inputs_}; + std::vector argument_inputs = {inputs_}; argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; VLOG(4) << "Builder construction outputs"; - ir::VectorType inputs = inputs_.type().dyn_cast(); + pir::VectorType inputs = inputs_.type().dyn_cast(); std::vector vec_dense_inputs; for (size_t i = 0; i < static_cast(inputs.size()); i++) { vec_dense_inputs.push_back(phi::DenseTensor( @@ -341,9 +343,9 @@ void 
AddNWithKernelOp::Build(ir::Builder &builder, phi::AddNInferMeta(meta_inputs, &meta_out); - std::vector argument_outputs; - ir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( - ir::IrContext::Instance(), + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_out.dtype()), dense_out.dims(), dense_out.layout(), @@ -365,7 +367,7 @@ void AddNWithKernelOp::Verify() { phi::errors::PreconditionNotMet( "The size %d of inputs must be equal to 1.", input_size)); if (auto vec_type = - (*this)->operand_source(0).type().dyn_cast()) { + (*this)->operand_source(0).type().dyn_cast()) { for (size_t i = 0; i < vec_type.size(); ++i) { PADDLE_ENFORCE(vec_type[i].isa() || vec_type[i].isa(), @@ -426,9 +428,9 @@ OpInfoTuple FusedGemmEpilogueOp::GetOpInfo() { false, false)}; std::vector attributes = { - paddle::dialect::OpAttributeInfo("trans_x", "ir::BoolAttribute", ""), - paddle::dialect::OpAttributeInfo("trans_y", "ir::BoolAttribute", ""), - paddle::dialect::OpAttributeInfo("activation", "ir::StrAttribute", "")}; + paddle::dialect::OpAttributeInfo("trans_x", "pir::BoolAttribute", ""), + paddle::dialect::OpAttributeInfo("trans_y", "pir::BoolAttribute", ""), + paddle::dialect::OpAttributeInfo("activation", "pir::StrAttribute", "")}; std::vector outputs = { paddle::dialect::OpOutputInfo( "out", "paddle::dialect::DenseTensorType", false, false), @@ -448,32 +450,44 @@ OpInfoTuple FusedGemmEpilogueOp::GetOpInfo() { inputs, attributes, outputs, run_time_info, "fused_gemm_epilogue"); } -void FusedGemmEpilogueOp::Build(ir::Builder &builder, - ir::OperationArgument &argument, - ir::OpResult x_, - ir::OpResult y_, - ir::OpResult bias_, - ir::AttributeMap attributes) { - bool trans_x = attributes.at("trans_x").dyn_cast().data(); - - bool trans_y = attributes.at("trans_y").dyn_cast().data(); - +void FusedGemmEpilogueOp::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::OpResult x_, + pir::OpResult y_, + pir::OpResult bias_, + pir::AttributeMap attributes) { + PADDLE_ENFORCE( + attributes.find("trans_x") != attributes.end(), + phi::errors::NotFound( + "'trans_x' Attribute is expected for FusedGemmEpilogueOp")); + bool trans_x = attributes.at("trans_x").dyn_cast().data(); + + PADDLE_ENFORCE( + attributes.find("trans_y") != attributes.end(), + phi::errors::NotFound( + "'trans_y' Attribute is expected for FusedGemmEpilogueOp")); + bool trans_y = attributes.at("trans_y").dyn_cast().data(); + + PADDLE_ENFORCE( + attributes.find("activation") != attributes.end(), + phi::errors::NotFound( + "'activation' Attribute is expected for FusedGemmEpilogueOp")); std::string activation = - attributes.at("activation").dyn_cast().AsString(); + attributes.at("activation").dyn_cast().AsString(); VLOG(4) << "Builder construction inputs"; - std::vector argument_inputs = {x_, y_, bias_}; + std::vector argument_inputs = {x_, y_, bias_}; argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; - ir::Attribute attr_trans_x = - ir::BoolAttribute::get(ir::IrContext::Instance(), trans_x); + pir::Attribute attr_trans_x = + pir::BoolAttribute::get(pir::IrContext::Instance(), trans_x); argument.AddAttribute("trans_x", attr_trans_x); - ir::Attribute attr_trans_y = - ir::BoolAttribute::get(ir::IrContext::Instance(), trans_y); + pir::Attribute attr_trans_y = + pir::BoolAttribute::get(pir::IrContext::Instance(), trans_y); 
argument.AddAttribute("trans_y", attr_trans_y); - ir::Attribute attr_activation = - ir::StrAttribute::get(ir::IrContext::Instance(), activation); + pir::Attribute attr_activation = + pir::StrAttribute::get(pir::IrContext::Instance(), activation); argument.AddAttribute("activation", attr_activation); VLOG(4) << "Builder construction outputs"; @@ -540,9 +554,9 @@ void FusedGemmEpilogueOp::Build(ir::Builder &builder, &meta_out, activation == "none" ? nullptr : &meta_reserve_space); - std::vector argument_outputs; - ir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( - ir::IrContext::Instance(), + std::vector argument_outputs; + pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_out.dtype()), dense_out.dims(), dense_out.layout(), @@ -550,11 +564,11 @@ void FusedGemmEpilogueOp::Build(ir::Builder &builder, dense_out.offset()); argument_outputs.push_back(out_dense_tensor_type); - ir::Type reserve_space_dense_tensor_type = + pir::Type reserve_space_dense_tensor_type = activation == "none" - ? ir::Type() + ? pir::Type() : paddle::dialect::DenseTensorType::get( - ir::IrContext::Instance(), + pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_reserve_space.dtype()), dense_reserve_space.dims(), dense_reserve_space.layout(), @@ -599,15 +613,15 @@ void FusedGemmEpilogueOp::Verify() { { auto &attributes = this->attributes(); PADDLE_ENFORCE(attributes.count("trans_x") > 0 && - attributes.at("trans_x").isa(), + attributes.at("trans_x").isa(), phi::errors::PreconditionNotMet( "Type of attribute: trans_x is not right.")); PADDLE_ENFORCE(attributes.count("trans_y") > 0 && - attributes.at("trans_y").isa(), + attributes.at("trans_y").isa(), phi::errors::PreconditionNotMet( "Type of attribute: trans_y is not right.")); PADDLE_ENFORCE(attributes.count("activation") > 0 && - attributes.at("activation").isa(), + attributes.at("activation").isa(), phi::errors::PreconditionNotMet( "Type of attribute: activation is not right.")); } @@ -659,10 +673,10 @@ OpInfoTuple FusedGemmEpilogueGradOp::GetOpInfo() { false, false)}; std::vector attributes = { - paddle::dialect::OpAttributeInfo("trans_x", "ir::BoolAttribute", ""), - paddle::dialect::OpAttributeInfo("trans_y", "ir::BoolAttribute", ""), + paddle::dialect::OpAttributeInfo("trans_x", "pir::BoolAttribute", ""), + paddle::dialect::OpAttributeInfo("trans_y", "pir::BoolAttribute", ""), paddle::dialect::OpAttributeInfo( - "activation_grad", "ir::StrAttribute", "")}; + "activation_grad", "pir::StrAttribute", "")}; std::vector outputs = { paddle::dialect::OpOutputInfo( "x_grad", "paddle::dialect::DenseTensorType", false, false), @@ -689,34 +703,46 @@ OpInfoTuple FusedGemmEpilogueGradOp::GetOpInfo() { inputs, attributes, outputs, run_time_info, "fused_gemm_epilogue_grad"); } -void FusedGemmEpilogueGradOp::Build(ir::Builder &builder, - ir::OperationArgument &argument, - ir::OpResult x_, - ir::OpResult y_, - ir::OpResult reserve_space_, - ir::OpResult out_grad_, - ir::AttributeMap attributes) { - bool trans_x = attributes.at("trans_x").dyn_cast().data(); - - bool trans_y = attributes.at("trans_y").dyn_cast().data(); - +void FusedGemmEpilogueGradOp::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::OpResult x_, + pir::OpResult y_, + pir::OpResult reserve_space_, + pir::OpResult out_grad_, + pir::AttributeMap attributes) { + PADDLE_ENFORCE( + attributes.find("trans_x") != attributes.end(), + phi::errors::NotFound( + "'trans_x' 
Attribute is expected for FusedGemmEpilogueGradOp")); + bool trans_x = attributes.at("trans_x").dyn_cast().data(); + + PADDLE_ENFORCE( + attributes.find("trans_y") != attributes.end(), + phi::errors::NotFound( + "'trans_y' Attribute is expected for FusedGemmEpilogueGradOp")); + bool trans_y = attributes.at("trans_y").dyn_cast().data(); + + PADDLE_ENFORCE( + attributes.find("activation_grad") != attributes.end(), + phi::errors::NotFound("'activation_grad' Attribute is expected for" + "FusedGemmEpilogueGradOp")); std::string activation_grad = - attributes.at("activation_grad").dyn_cast().AsString(); + attributes.at("activation_grad").dyn_cast().AsString(); VLOG(4) << "Builder construction inputs"; - std::vector argument_inputs = { + std::vector argument_inputs = { x_, y_, reserve_space_, out_grad_}; argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; - ir::Attribute attr_trans_x = - ir::BoolAttribute::get(ir::IrContext::Instance(), trans_x); + pir::Attribute attr_trans_x = + pir::BoolAttribute::get(pir::IrContext::Instance(), trans_x); argument.AddAttribute("trans_x", attr_trans_x); - ir::Attribute attr_trans_y = - ir::BoolAttribute::get(ir::IrContext::Instance(), trans_y); + pir::Attribute attr_trans_y = + pir::BoolAttribute::get(pir::IrContext::Instance(), trans_y); argument.AddAttribute("trans_y", attr_trans_y); - ir::Attribute attr_activation_grad = - ir::StrAttribute::get(ir::IrContext::Instance(), activation_grad); + pir::Attribute attr_activation_grad = + pir::StrAttribute::get(pir::IrContext::Instance(), activation_grad); argument.AddAttribute("activation_grad", attr_activation_grad); VLOG(4) << "Builder construction outputs"; @@ -809,9 +835,9 @@ void FusedGemmEpilogueGradOp::Build(ir::Builder &builder, &meta_y_grad, &meta_bias_grad); - std::vector argument_outputs; - ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( - ir::IrContext::Instance(), + std::vector argument_outputs; + pir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_x_grad.dtype()), dense_x_grad.dims(), dense_x_grad.layout(), @@ -819,8 +845,8 @@ void FusedGemmEpilogueGradOp::Build(ir::Builder &builder, dense_x_grad.offset()); argument_outputs.push_back(x_grad_dense_tensor_type); - ir::Type y_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( - ir::IrContext::Instance(), + pir::Type y_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_y_grad.dtype()), dense_y_grad.dims(), dense_y_grad.layout(), @@ -828,8 +854,8 @@ void FusedGemmEpilogueGradOp::Build(ir::Builder &builder, dense_y_grad.offset()); argument_outputs.push_back(y_grad_dense_tensor_type); - ir::Type bias_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( - ir::IrContext::Instance(), + pir::Type bias_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_bias_grad.dtype()), dense_bias_grad.dims(), dense_bias_grad.layout(), @@ -851,7 +877,7 @@ const char *SplitGradOp::attributes_name[1] = {"axis"}; OpInfoTuple SplitGradOp::GetOpInfo() { std::vector inputs = { OpInputInfo("out_grad", - "ir::VectorType", + "pir::VectorType", false, false, false, @@ -879,23 +905,23 @@ OpInfoTuple SplitGradOp::GetOpInfo() { inputs, attributes, outputs, run_time_info, "split_grad"); } -void 
SplitGradOp::Build(ir::Builder &builder, - ir::OperationArgument &argument, - ir::OpResult out_grad_, +void SplitGradOp::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::OpResult out_grad_, float axis) { // Generate scalar mutable attribute: axis paddle::dialect::FullOp full_axis_op = builder.Build( std::vector{1}, axis, phi::DataType::FLOAT32, phi::CPUPlace()); - ir::OpResult axis_ = full_axis_op->result(0); + pir::OpResult axis_ = full_axis_op->result(0); VLOG(4) << "Builder construction inputs"; - std::vector argument_inputs = {out_grad_, axis_}; + std::vector argument_inputs = {out_grad_, axis_}; argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; VLOG(4) << "Builder construction outputs"; - ir::VectorType out_grad = out_grad_.type().dyn_cast(); + pir::VectorType out_grad = out_grad_.type().dyn_cast(); std::vector vec_dense_out_grad; for (size_t i = 0; i < static_cast(out_grad.size()); i++) { vec_dense_out_grad.push_back(phi::DenseTensor( @@ -930,9 +956,9 @@ void SplitGradOp::Build(ir::Builder &builder, phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_grad); - std::vector argument_outputs; - ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( - ir::IrContext::Instance(), + std::vector argument_outputs; + pir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_x_grad.dtype()), dense_x_grad.dims(), dense_x_grad.layout(), @@ -942,18 +968,18 @@ void SplitGradOp::Build(ir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void SplitGradOp::Build(ir::Builder &builder, - ir::OperationArgument &argument, - ir::OpResult out_grad_, - ir::OpResult axis_) { +void SplitGradOp::Build(pir::Builder &builder, + pir::OperationArgument &argument, + pir::OpResult out_grad_, + pir::OpResult axis_) { VLOG(4) << "Builder construction inputs"; - std::vector argument_inputs = {out_grad_, axis_}; + std::vector argument_inputs = {out_grad_, axis_}; argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); VLOG(4) << "Builder construction attributes"; VLOG(4) << "Builder construction outputs"; - ir::VectorType out_grad = out_grad_.type().dyn_cast(); + pir::VectorType out_grad = out_grad_.type().dyn_cast(); int axis = axis_.owner() ->dyn_cast() .attributes() @@ -995,9 +1021,9 @@ void SplitGradOp::Build(ir::Builder &builder, phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_grad); - std::vector argument_outputs; - ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( - ir::IrContext::Instance(), + std::vector argument_outputs; + pir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), TransToIrDataType(dense_x_grad.dtype()), dense_x_grad.dims(), dense_x_grad.layout(), @@ -1018,7 +1044,7 @@ void SplitGradOp::Verify() { phi::errors::PreconditionNotMet( "The size %d of inputs must be equal to 2.", input_size)); if (auto vec_type = - (*this)->operand_source(0).type().dyn_cast()) { + (*this)->operand_source(0).type().dyn_cast()) { for (size_t i = 0; i < vec_type.size(); ++i) { PADDLE_ENFORCE(vec_type[i].isa(), phi::errors::PreconditionNotMet( @@ -1064,29 +1090,29 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } -void IfOp::Build(ir::Builder &builder, // NOLINT - ir::OperationArgument &argument, // NOLINT - ir::OpResult cond, - std::vector &&output_types) { +void 
IfOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::OpResult cond, + std::vector &&output_types) { argument.num_regions = 2; argument.AddOperand(cond); argument.output_types.swap(output_types); } -ir::Block *IfOp::true_block() { - ir::Region &true_region = (*this)->region(0); +pir::Block *IfOp::true_block() { + pir::Region &true_region = (*this)->region(0); if (true_region.empty()) true_region.emplace_back(); return true_region.front(); } -ir::Block *IfOp::false_block() { - ir::Region &false_region = (*this)->region(1); +pir::Block *IfOp::false_block() { + pir::Region &false_region = (*this)->region(1); if (false_region.empty()) false_region.emplace_back(); return false_region.front(); } -void IfOp::Print(ir::IrPrinter &printer) { +void IfOp::Print(pir::IrPrinter &printer) { auto &os = printer.os; auto op = operation(); printer.PrintOpResult(op); - os << " = pd.if"; + os << " = pd_op.if"; printer.PrintOpOperands(op); os << " -> "; printer.PrintOpReturnType(op); diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h new file mode 100644 index 0000000000000..8cd8b9021858f --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -0,0 +1,204 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
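The header body below opens with an #ifdef GET_MANUAL_OP_LIST guard. This is the include-twice idiom: included normally, the file declares the op classes; included with the macro defined, it expands to nothing but the comma-separated op list, ready to be spliced into a RegisterOps<...>() call. A minimal self-contained analogue (all names made up for the example):

  // my_ops.h -- included twice with different meanings:
  #ifdef GET_MY_OP_LIST
  #undef GET_MY_OP_LIST
  FooOp, BarOp  // bare comma-separated list, no declarations
  #else
  #pragma once
  struct FooOp {};
  struct BarOp {};
  #endif

  // consumer.cc -- splices the list into a variadic template call:
  //   RegisterOps<
  //   #define GET_MY_OP_LIST
  //   #include "my_ops.h"
  //       >();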
+ +#ifdef GET_MANUAL_OP_LIST +#undef GET_MANUAL_OP_LIST +paddle::dialect::AddNOp, paddle::dialect::SplitGradOp, paddle::dialect::IfOp + +#else + +#pragma once +#include + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/trait/inplace.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/ir_printer.h" +#include "paddle/pir/core/op_base.h" +#include "paddle/pir/core/operation_utils.h" + +namespace paddle { +namespace dialect { + +class AddNOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.add_n"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + static OpInfoTuple GetOpInfo(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::OpResult inputs); + + void Verify(); + pir::Value inputs() { return operand_source(0); } + pir::OpResult out() { return result(0); } + static void InferMeta(phi::InferMetaContext *infer_meta); +}; + +class AddN_Op : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.add_n_"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + static OpInfoTuple GetOpInfo(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::OpResult inputs_); + + void Verify(); + pir::Value inputs() { return operand_source(0); } + pir::OpResult out() { return result(0); } + + static void InferMeta(phi::InferMetaContext *infer_meta); +}; + +class AddNWithKernelOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.add_n_with_kernel"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + static OpInfoTuple GetOpInfo(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::OpResult inputs_); + + void Verify(); + pir::Value inputs() { return operand_source(0); } + pir::OpResult out() { return result(0); } + + static void InferMeta(phi::InferMetaContext *infer_meta); +}; + +class FusedGemmEpilogueOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.fused_gemm_epilogue"; } + static const char *attributes_name[3]; + static constexpr uint32_t attributes_num = 3; + static OpInfoTuple GetOpInfo(); + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::OpResult x_, + pir::OpResult y_, + pir::OpResult bias_, + pir::AttributeMap attributes); + void Verify(); + pir::Value x() { return operand_source(0); } + pir::Value y() { return operand_source(1); } + pir::Value bias() { return operand_source(2); } + pir::OpResult out() { return result(0); } + pir::OpResult reserve_space() { return result(1); } + + static void InferMeta(phi::InferMetaContext *infer_meta); +}; + +class FusedGemmEpilogueGradOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.fused_gemm_epilogue_grad"; } + static const char *attributes_name[3]; + static constexpr uint32_t attributes_num = 3; + 
static OpInfoTuple GetOpInfo(); + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::OpResult x_, + pir::OpResult y_, + pir::OpResult reserve_space_, + pir::OpResult out_grad_, + pir::AttributeMap attributes); + void Verify(); + pir::Value x() { return operand_source(0); } + pir::Value y() { return operand_source(1); } + pir::Value reserve_space() { return operand_source(2); } + pir::Value out_grad() { return operand_source(3); } + pir::OpResult x_grad() { return result(0); } + pir::OpResult y_grad() { return result(1); } + pir::OpResult bias_grad() { return result(2); } + + static void InferMeta(phi::InferMetaContext *infer_meta); +}; + +class SplitGradOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.split_grad"; } + static const char *attributes_name[1]; + static constexpr uint32_t attributes_num = 1; + static OpInfoTuple GetOpInfo(); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::OpResult x_, + float axis = 0); + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::OpResult out_grad_, + pir::OpResult axis_); + + void Verify(); + pir::Value out_grad() { return operand_source(0); } + pir::Value axis() { return operand_source(1); } + pir::OpResult x_grad() { return result(0); } + static void InferMeta(phi::InferMetaContext *infer_meta); +}; + +class IfOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.if"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::OpResult cond, + std::vector &&output_types); + pir::Value cond() { return operand_source(0); } + pir::Block *true_block(); + pir::Block *false_block(); + void Print(pir::IrPrinter &printer); // NOLINT + void Verify(); +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) +#endif diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc similarity index 61% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h rename to paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc index 5eba73e5182bd..7f58a434f554a 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_vjp.cc @@ -12,19 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. 
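One note on the IfOp declared in manual_op.h above: true_block() and false_block() lazily emplace the first block of their region, so both branches can be populated immediately after Build(). A hand-written sketch, assuming a pir::Builder `builder`, a boolean-typed value `cond`, an output type `out_type`, and a SetInsertionPointToStart helper on the builder:

  auto if_op = builder.Build<paddle::dialect::IfOp>(
      cond, std::vector<pir::Type>{out_type});
  builder.SetInsertionPointToStart(if_op.true_block());
  // ... build the then-branch ops here ...
  builder.SetInsertionPointToStart(if_op.false_block());
  // ... build the else-branch ops here ...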
-#pragma once +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/primitive/rule/vjp/vjp.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/op_base.h" -#include - -#include "paddle/ir/core/value.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/place.h" +// TODO(wanghao107) +// this file will be generated in pd_op.cc namespace paddle { namespace dialect { +using IntArray = paddle::experimental::IntArray; -ir::OpResult split_grad(std::vector out_grads, ir::OpResult axis); - -ir::OpResult split_grad(std::vector out_grads, int axis); } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.cc b/paddle/fluid/pir/dialect/operator/ir/meta_tensor.cc similarity index 95% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.cc rename to paddle/fluid/pir/dialect/operator/ir/meta_tensor.cc index 2da7b098a6556..1985413ecb95d 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.cc +++ b/paddle/fluid/pir/dialect/operator/ir/meta_tensor.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.h" +#include "paddle/fluid/pir/dialect/operator/ir/meta_tensor.h" -#include "paddle/ir/core/enforce.h" +#include "paddle/pir/core/enforce.h" namespace paddle { namespace dialect { diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.h b/paddle/fluid/pir/dialect/operator/ir/meta_tensor.h similarity index 100% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.h rename to paddle/fluid/pir/dialect/operator/ir/meta_tensor.h diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.cc b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc similarity index 85% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.cc rename to paddle/fluid/pir/dialect/operator/ir/op_attribute.cc index 72cc98447e10e..3b69d68eb65f3 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" namespace paddle { namespace dialect { @@ -29,18 +29,18 @@ phi::DataLayout DataLayoutAttribute::data() const { } phi::Scalar ScalarAttribute::data() { - if (isa()) { - return phi::Scalar(dyn_cast().data()); - } else if (isa()) { - return phi::Scalar(dyn_cast().data()); - } else if (isa()) { - return phi::Scalar(dyn_cast().data()); - } else if (isa()) { - return phi::Scalar(dyn_cast().data()); - } else if (isa()) { - return phi::Scalar(dyn_cast().data()); - } else if (isa()) { - return phi::Scalar(dyn_cast().AsString()); + if (isa()) { + return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().data()); + } else if (isa()) { + return phi::Scalar(dyn_cast().AsString()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported ir attribute when casting it into " @@ -48,7 +48,7 @@ phi::Scalar ScalarAttribute::data() { } } -IntArrayAttribute IntArrayAttribute::Parse(ir::IrParser &parser) { // NOLINT +IntArrayAttribute IntArrayAttribute::Parse(pir::IrParser &parser) { // NOLINT Token buket_token = parser.ConsumeToken(); std::vector vec{}; while (parser.PeekToken().val_ != "]") { @@ -66,7 +66,7 @@ IntArrayAttribute IntArrayAttribute::Parse(ir::IrParser &parser) { // NOLINT // |int32|uint64|int64|float32|complex64 // |complex128|Undefined|psting|flaot16 // |bfloat16|num_data_types|all_dtype -DataTypeAttribute DataTypeAttribute::Parse(ir::IrParser &parser) { // NOLINT +DataTypeAttribute DataTypeAttribute::Parse(pir::IrParser &parser) { // NOLINT std::unordered_map StringToDataType{ {"bool", phi::DataType::BOOL}, {"uint8", phi::DataType::UINT8}, @@ -96,7 +96,7 @@ DataTypeAttribute DataTypeAttribute::Parse(ir::IrParser &parser) { // NOLINT // Parse a PlaceAttribute // PlaceAttribute := Place(cpu)|Place(gpu:0)|Place(gpu_pinned) // |Place(xpu:0)|Place(ipu:0)|Place(:0)|undefined -PlaceAttribute PlaceAttribute::Parse(ir::IrParser &parser) { // NOLINT +PlaceAttribute PlaceAttribute::Parse(pir::IrParser &parser) { // NOLINT std::unordered_map StringToPlace{ {"cpu", phi::CPUPlace{}}, {"gpu", phi::GPUPlace{}}, @@ -126,7 +126,7 @@ PlaceAttribute PlaceAttribute::Parse(ir::IrParser &parser) { // NOLINT // |SPARSE_COO|SPARSE_CSR|NDHWC // |NCDHW|PSTRING_UNION|STRIDED DataLayoutAttribute DataLayoutAttribute::Parse( - ir::IrParser &parser) { // NOLINT + pir::IrParser &parser) { // NOLINT std::unordered_map StringToDataLayout{ {"NHWC", phi::DataLayout::kNHWC}, {"NCHW", phi::DataLayout::kNCHW}, diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h b/paddle/fluid/pir/dialect/operator/ir/op_attribute.h similarity index 65% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h rename to paddle/fluid/pir/dialect/operator/ir/op_attribute.h index e1d3daab7191d..a47187774eeb6 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.h @@ -14,17 +14,17 @@ #pragma once -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute_storage.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/ir_parser.h" +#include 
"paddle/fluid/pir/dialect/operator/ir/attribute_storage.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/parser/ir_parser.h" namespace paddle { namespace dialect { -class IntArrayAttribute : public ir::Attribute { +class IntArrayAttribute : public pir::Attribute { public: using Attribute::Attribute; @@ -35,32 +35,32 @@ class IntArrayAttribute : public ir::Attribute { return storage() < right.storage(); } - static IntArrayAttribute Parse(ir::IrParser &parser); // NOLINT + static IntArrayAttribute Parse(pir::IrParser &parser); // NOLINT const phi::IntArray &data() const; }; -class ScalarAttribute : public ir::Attribute { +class ScalarAttribute : public pir::Attribute { public: using Attribute::Attribute; - static bool classof(ir::Attribute val) { - return (val.type_id() == ir::BoolAttribute::type_id()) || - (val.type_id() == ir::FloatAttribute::type_id()) || - (val.type_id() == ir::DoubleAttribute::type_id()) || - (val.type_id() == ir::Int32Attribute::type_id()) || - (val.type_id() == ir::Int64Attribute::type_id()) || - (val.type_id() == ir::StrAttribute::type_id()); + static bool classof(pir::Attribute val) { + return (val.type_id() == pir::BoolAttribute::type_id()) || + (val.type_id() == pir::FloatAttribute::type_id()) || + (val.type_id() == pir::DoubleAttribute::type_id()) || + (val.type_id() == pir::Int32Attribute::type_id()) || + (val.type_id() == pir::Int64Attribute::type_id()) || + (val.type_id() == pir::StrAttribute::type_id()); } - static ir::Attribute get(ir::IrContext *ctx, phi::Scalar scalar) { + static pir::Attribute get(pir::IrContext *ctx, phi::Scalar scalar) { return TransToIrAttribute(scalar, ctx); } phi::Scalar data(); }; -class DataTypeAttribute : public ir::Attribute { +class DataTypeAttribute : public pir::Attribute { public: using Attribute::Attribute; @@ -71,12 +71,12 @@ class DataTypeAttribute : public ir::Attribute { return storage() < right.storage(); } - static DataTypeAttribute Parse(ir::IrParser &parser); // NOLINT + static DataTypeAttribute Parse(pir::IrParser &parser); // NOLINT phi::DataType data() const; }; -class PlaceAttribute : public ir::Attribute { +class PlaceAttribute : public pir::Attribute { public: using Attribute::Attribute; @@ -86,12 +86,12 @@ class PlaceAttribute : public ir::Attribute { return storage() < right.storage(); } - static PlaceAttribute Parse(ir::IrParser &parser); // NOLINT + static PlaceAttribute Parse(pir::IrParser &parser); // NOLINT phi::Place data() const; }; -class DataLayoutAttribute : public ir::Attribute { +class DataLayoutAttribute : public pir::Attribute { public: using Attribute::Attribute; @@ -102,7 +102,7 @@ class DataLayoutAttribute : public ir::Attribute { return storage() < right.storage(); } - static DataLayoutAttribute Parse(ir::IrParser &parser); // NOLINT + static DataLayoutAttribute Parse(pir::IrParser &parser); // NOLINT phi::DataLayout data() const; }; diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc similarity index 75% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc rename to paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 82169dafc5969..2c85ea18d3da3 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -12,26 +12,26 @@ // See the 
License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in -// paddle/fluid/ir/dialect/CMakeLists.txt. -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.h" -#include "paddle/ir/core/ir_printer.h" -#include "paddle/ir/core/utils.h" +// paddle/fluid/pir/dialect/CMakeLists.txt. +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" +#include "paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h" +#include "paddle/pir/core/ir_printer.h" +#include "paddle/pir/core/utils.h" namespace paddle { namespace dialect { -PaddleDialect::PaddleDialect(ir::IrContext *context) - : ir::Dialect(name(), context, ir::TypeId::get<PaddleDialect>()) { +OperatorDialect::OperatorDialect(pir::IrContext *context) + : pir::Dialect(name(), context, pir::TypeId::get<OperatorDialect>()) { initialize(); } -void PaddleDialect::initialize() { +void OperatorDialect::initialize() { RegisterTypes<paddle::dialect::DenseTensorType>(); RegisterTypes<paddle::dialect::SelectedRowsType>(); @@ -42,12 +42,12 @@ void PaddleDialect::initialize() { // NOTE(zhangbo9674): GET_OP_LIST is defined in pd_op.h which is // generated by op_gen.py, see details in - // paddle/fluid/ir/dialect/CMakeLists.txt. - // NOTE(Ruting)GET_MANUAL_OP_LIST is define in pd_manual_op.h" + // paddle/fluid/pir/dialect/CMakeLists.txt. + // NOTE(Ruting): GET_MANUAL_OP_LIST is defined in manual_op.h. // use RegisterOps when list has more than two ops.
RegisterOps< #define GET_OP_LIST -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" // NOLINT +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" // NOLINT >(); RegisterOps(); } -void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const { +void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const { os << type.dialect().name(); os << '.'; if (auto tensor_type = type.dyn_cast()) { @@ -82,7 +82,8 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const { } } -void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { +void OperatorDialect::PrintAttribute(pir::Attribute attr, + std::ostream &os) const { os << "(" << attr.dialect().name(); os << '.'; if (auto int_array_attr = attr.dyn_cast()) { @@ -90,7 +91,7 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { os << "IntArray)" << "["; const auto &inner_data = data.GetData(); - ir::PrintInterleave( + pir::PrintInterleave( inner_data.begin(), inner_data.end(), [&os](int64_t i) { os << i; }, @@ -107,8 +108,8 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { } } -ir::Type PaddleDialect::ParseType(ir::IrParser &parser) { // NOLINT - parser.ConsumeAToken("pd.tensor"); +pir::Type OperatorDialect::ParseType(pir::IrParser &parser) { // NOLINT + parser.ConsumeAToken("pd_op.tensor"); parser.ConsumeAToken("<"); std::vector dim{}; Token dim_token = parser.PeekToken(); @@ -126,7 +127,7 @@ ir::Type PaddleDialect::ParseType(ir::IrParser &parser) { // NOLINT } } phi::DDim ddim = phi::make_ddim(dim); - ir::Type dtype = parser.ParseType(); + pir::Type dtype = parser.ParseType(); std::vector> lod; std::vector lodv; lodv.push_back(0); @@ -136,7 +137,8 @@ ir::Type PaddleDialect::ParseType(ir::IrParser &parser) { // NOLINT parser.ctx, dtype, ddim, phi::DataLayout::UNDEFINED, lod, 0); } -ir::Attribute PaddleDialect::ParseAttribute(ir::IrParser &parser) { // NOLINT +pir::Attribute OperatorDialect::ParseAttribute( + pir::IrParser &parser) { // NOLINT std::string type_name = parser.ConsumeToken().val_; std::string attribute_name = type_name.substr(type_name.find('.') + 1, std::string::npos); @@ -155,8 +157,8 @@ ir::Attribute PaddleDialect::ParseAttribute(ir::IrParser &parser) { // NOLINT } } -void PaddleDialect::PrintOperation(ir::Operation *op, - ir::IrPrinter &printer) const { +void OperatorDialect::PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const { if (auto if_op = op->dyn_cast()) { if_op.Print(printer); } else { @@ -167,4 +169,4 @@ void PaddleDialect::PrintOperation(ir::Operation *op, } // namespace dialect } // namespace paddle -IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::PaddleDialect) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OperatorDialect) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h b/paddle/fluid/pir/dialect/operator/ir/op_dialect.h similarity index 54% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h rename to paddle/fluid/pir/dialect/operator/ir/op_dialect.h index 285a796982f85..bc85b789a058b 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.h @@ -14,25 +14,25 @@ #pragma once -#include "paddle/ir/core/dialect.h" +#include "paddle/pir/core/dialect.h" namespace paddle { namespace dialect { -class PaddleDialect : public ir::Dialect { +class OperatorDialect : public pir::Dialect { public: - explicit PaddleDialect(ir::IrContext* context); + explicit OperatorDialect(pir::IrContext* 
context); - static const char* name() { return "pd"; } + static const char* name() { return "pd_op"; } - ir::Type ParseType(ir::IrParser& parser) override; // NOLINT - ir::Attribute ParseAttribute(ir::IrParser& parser) override; // NOLINT + pir::Type ParseType(pir::IrParser& parser) override; // NOLINT + pir::Attribute ParseAttribute(pir::IrParser& parser) override; // NOLINT - void PrintType(ir::Type type, std::ostream& os) const override; - void PrintAttribute(ir::Attribute type, std::ostream& os) const override; + void PrintType(pir::Type type, std::ostream& os) const override; + void PrintAttribute(pir::Attribute type, std::ostream& os) const override; - void PrintOperation(ir::Operation* op, - ir::IrPrinter& printer) const override; // NOLINT + void PrintOperation(pir::Operation* op, + pir::IrPrinter& printer) const override; // NOLINT private: void initialize(); @@ -41,4 +41,4 @@ class PaddleDialect : public ir::Dialect { } // namespace dialect } // namespace paddle -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::PaddleDialect) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OperatorDialect) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc b/paddle/fluid/pir/dialect/operator/ir/op_type.cc similarity index 88% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc rename to paddle/fluid/pir/dialect/operator/ir/op_type.cc index 31ba23b0e1bbc..c9fc8bcd65b10 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" namespace paddle { namespace dialect { -const ir::Type& SelectedRowsType::dtype() const { return storage()->dtype_; } +const pir::Type& SelectedRowsType::dtype() const { return storage()->dtype_; } const phi::DDim& SelectedRowsType::dims() const { return storage()->dims_; } diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h b/paddle/fluid/pir/dialect/operator/ir/op_type.h similarity index 62% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h rename to paddle/fluid/pir/dialect/operator/ir/op_type.h index 9525e1a88b346..3ee0d642e2e47 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.h @@ -14,20 +14,23 @@ #pragma once -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h" -#include "paddle/ir/core/builtin_type.h" -#include "paddle/ir/core/type.h" +#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/builtin_type_interfaces.h" +#include "paddle/pir/core/type.h" namespace paddle { namespace dialect { -using DenseTensorType = ir::DenseTensorType; -class SelectedRowsType : public ir::Type { - public: - using Type::Type; - DECLARE_TYPE_UTILITY_FUNCTOR(SelectedRowsType, SelectedRowsTypeStorage); +using DenseTensorType = pir::DenseTensorType; +class SelectedRowsType : public pir::Type::TypeBase<SelectedRowsType, pir::Type, SelectedRowsTypeStorage> { + public: + using Base::Base; - const ir::Type &dtype() const; + const pir::Type &dtype() const; const phi::DDim &dims() const; diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml similarity index 96% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops.yaml rename to
paddle/fluid/pir/dialect/operator/ir/ops.yaml index da4c252af7217..bf80652d03134 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -106,6 +106,15 @@ param: [x, file_path, overwrite, save_as_fp16, save_to_memory] optional : out +- op : seed + args : (int seed, bool deterministic, str rng_name, bool force_cpu) + output : Tensor(out) + infer_meta: + func: SeedInferMeta + param: [seed] + kernel: + func: seed + - op : send_v2 args : (Tensor x, int ring_id = 0, int peer = 0, bool use_calc_stream = false, bool dynamic_shape = false) output : diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml similarity index 100% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_ops_backward.yaml rename to paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h b/paddle/fluid/pir/dialect/operator/ir/type_storage.h similarity index 78% rename from paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h rename to paddle/fluid/pir/dialect/operator/ir/type_storage.h index 1a74b6d6c1059..e001f7b78716b 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h +++ b/paddle/fluid/pir/dialect/operator/ir/type_storage.h @@ -16,17 +16,17 @@ #include <type_traits> -#include "paddle/ir/core/builtin_type_storage.h" -#include "paddle/ir/core/type.h" -#include "paddle/ir/core/type_base.h" -#include "paddle/ir/core/utils.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/pir/core/builtin_type_storage.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/type_base.h" +#include "paddle/pir/core/utils.h" namespace paddle { namespace dialect { -using DenseTensorTypeStorage = ir::DenseTensorTypeStorage; +using DenseTensorTypeStorage = pir::DenseTensorTypeStorage; -struct SelectedRowsTypeStorage : public ir::TypeStorage { +struct SelectedRowsTypeStorage : public pir::TypeStorage { using DataLayout = phi::DataLayout; using Dim = phi::DDim; using LoD = std::vector<std::vector<size_t>>; @@ -34,9 +34,9 @@ struct SelectedRowsTypeStorage : public ir::TypeStorage { /// \brief Declare ParamKey according to parameter type. /// using ParamKey = - std::tuple<ir::Type, phi::DDim, phi::DataLayout, phi::LoD, size_t>; + std::tuple<pir::Type, phi::DDim, phi::DataLayout, phi::LoD, size_t>; - SelectedRowsTypeStorage(const ir::Type& dtype, + SelectedRowsTypeStorage(const pir::Type& dtype, const phi::DDim& dims, const phi::DataLayout& layout, const phi::LoD& lod, @@ -66,22 +66,22 @@ struct SelectedRowsTypeStorage : public ir::TypeStorage { std::size_t hash_value = 317; // hash dtype hash_value = - ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key))); + pir::hash_combine(hash_value, std::hash<pir::Type>()(std::get<0>(key))); // hash dims hash_value = - ir::hash_combine(hash_value, std::hash<phi::DDim>()(std::get<1>(key))); + pir::hash_combine(hash_value, std::hash<phi::DDim>()(std::get<1>(key))); // hash layout - hash_value = ir::hash_combine( + hash_value = pir::hash_combine( hash_value, std::hash<std::underlying_type<DataLayout>::type>()( static_cast<std::underlying_type<DataLayout>::type>( std::get<2>(key)))); // hash lod hash_value = - ir::hash_combine(hash_value, std::hash<phi::LoD>()(std::get<3>(key))); + pir::hash_combine(hash_value, std::hash<phi::LoD>()(std::get<3>(key))); // hash offset hash_value = - ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key))); + pir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key))); return hash_value; } @@ -100,7 +100,7 @@ struct SelectedRowsTypeStorage : public ir::TypeStorage { /// \brief SelectedRowsTypeStorage includes five parameters: dims, dtype, /// layout, lod, offset.
/// - ir::Type dtype_; + pir::Type dtype_; phi::DDim dims_; phi::DataLayout layout_; phi::LoD lod_; diff --git a/paddle/fluid/ir/dialect/paddle_dialect/trait/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/trait/CMakeLists.txt similarity index 83% rename from paddle/fluid/ir/dialect/paddle_dialect/trait/CMakeLists.txt rename to paddle/fluid/pir/dialect/operator/trait/CMakeLists.txt index 53c3060d6f182..0689edb35655e 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/trait/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/operator/trait/CMakeLists.txt @@ -3,4 +3,4 @@ file(GLOB PD_INTERFACE_SRCS "*.cc") cc_library( pd_trait SRCS ${PD_INTERFACE_SRCS} - DEPS ir_core) + DEPS pir_core) diff --git a/paddle/fluid/pir/dialect/operator/trait/custom_vjp.h b/paddle/fluid/pir/dialect/operator/trait/custom_vjp.h new file mode 100644 index 0000000000000..1b1c7c08efca1 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/trait/custom_vjp.h @@ -0,0 +1,38 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* +Custom VJP stands for manually implemented backward rules for composite +operators. CustomVjpTrait is added to those composite operators that +define custom vjp rules. By calling has_custom_vjp(op), users can check +whether an operator carries CustomVjpTrait, and therefore whether a custom +vjp rule is defined for that operator.
+*/ + +#pragma once + +#include "paddle/pir/core/op_base.h" + +namespace paddle { +namespace dialect { +class CustomVjpTrait : public pir::OpTraitBase<CustomVjpTrait> { + public: + explicit CustomVjpTrait(pir::Operation *op) + : pir::OpTraitBase<CustomVjpTrait>(op) {} +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::CustomVjpTrait) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/trait/inplace.h b/paddle/fluid/pir/dialect/operator/trait/inplace.h similarity index 80% rename from paddle/fluid/ir/dialect/paddle_dialect/trait/inplace.h rename to paddle/fluid/pir/dialect/operator/trait/inplace.h index 38dfaaeac000e..e50f1e3a8349d 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/trait/inplace.h +++ b/paddle/fluid/pir/dialect/operator/trait/inplace.h @@ -14,14 +14,14 @@ #pragma once -#include "paddle/ir/core/op_base.h" +#include "paddle/pir/core/op_base.h" namespace paddle { namespace dialect { -class InplaceTrait : public ir::OpTraitBase<InplaceTrait> { +class InplaceTrait : public pir::OpTraitBase<InplaceTrait> { public: - explicit InplaceTrait(ir::Operation *op) - : ir::OpTraitBase<InplaceTrait>(op) {} + explicit InplaceTrait(pir::Operation *op) + : pir::OpTraitBase<InplaceTrait>(op) {} }; } // namespace dialect diff --git a/paddle/fluid/ir/dialect/paddle_dialect/trait/trait.cc b/paddle/fluid/pir/dialect/operator/trait/trait.cc similarity index 78% rename from paddle/fluid/ir/dialect/paddle_dialect/trait/trait.cc rename to paddle/fluid/pir/dialect/operator/trait/trait.cc index c086b98e34bc7..2a5b7575959b9 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/trait/trait.cc +++ b/paddle/fluid/pir/dialect/operator/trait/trait.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_dialect/trait/inplace.h" +#include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" +#include "paddle/fluid/pir/dialect/operator/trait/inplace.h" IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InplaceTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::CustomVjpTrait) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/transforms/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/transforms/CMakeLists.txt similarity index 68% rename from paddle/fluid/ir/dialect/paddle_dialect/transforms/CMakeLists.txt rename to paddle/fluid/pir/dialect/operator/transforms/CMakeLists.txt index 8d90edd3feb74..7116a12be50ef 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/transforms/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/operator/transforms/CMakeLists.txt @@ -1,4 +1,4 @@ cc_library( param_to_variable SRCS param_to_variable.cc - DEPS pd_dialect_core) + DEPS pd_op_dialect_core) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.cc b/paddle/fluid/pir/dialect/operator/transforms/param_to_variable.cc similarity index 73% rename from paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.cc rename to paddle/fluid/pir/dialect/operator/transforms/param_to_variable.cc index 0113e38b8fd5e..1d93e27c59b0b 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.cc +++ b/paddle/fluid/pir/dialect/operator/transforms/param_to_variable.cc @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License.
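CustomVjpTrait and InplaceTrait above are plain marker traits. A minimal sketch of the has_custom_vjp(op) helper that the custom_vjp.h comment refers to, assuming pir::Operation exposes the usual HasTrait<T>() query (that query is an assumption; only the trait classes themselves appear in this patch):

#include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h"

// Returns true when a hand-written VJP rule is registered for this op.
static bool has_custom_vjp(pir::Operation *op) {
  return op->HasTrait<paddle::dialect::CustomVjpTrait>();
}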
-#include "paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.h" +#include "paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" namespace paddle { namespace dialect { std::shared_ptr -ParameterConvertInterface::ParameterToVariable(ir::Parameter *parameter) { +ParameterConvertInterface::ParameterToVariable(pir::Parameter *parameter) { if (parameter->type().isa()) { VLOG(4) << "Convert a DenseTensor Parameter to a variable."; std::shared_ptr var = @@ -56,21 +56,21 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter *parameter) { } } -std::unique_ptr ParameterConvertInterface::VariableToParameter( +std::unique_ptr ParameterConvertInterface::VariableToParameter( paddle::framework::Variable *var) { if (var->IsType()) { phi::DenseTensor *tensor = var->GetMutable(); // Get Meta - ir::IrContext *ctx = ir::IrContext::Instance(); - ir::Type data_type = TransToIrDataType(tensor->dtype(), ctx); + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Type data_type = TransToIrDataType(tensor->dtype(), ctx); void *data = tensor->data(); - ir::Type dense_tensor_type = DenseTensorType::get(ctx, - data_type, - tensor->dims(), - tensor->layout(), - tensor->lod(), - tensor->meta().offset); - return std::make_unique( + pir::Type dense_tensor_type = DenseTensorType::get(ctx, + data_type, + tensor->dims(), + tensor->layout(), + tensor->lod(), + tensor->meta().offset); + return std::make_unique( data, tensor->numel() * phi::SizeOf(tensor->dtype()), dense_tensor_type); diff --git a/paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.h b/paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h similarity index 76% rename from paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.h rename to paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h index 4194cbae53ddf..bdb7bed12c970 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/transforms/param_to_variable.h +++ b/paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h @@ -14,21 +14,21 @@ #pragma once #include "paddle/fluid/framework/variable.h" -#include "paddle/ir/core/dialect_interface.h" -#include "paddle/ir/core/parameter.h" +#include "paddle/pir/core/dialect_interface.h" +#include "paddle/pir/core/parameter.h" namespace paddle { namespace dialect { class ParameterConvertInterface - : public ir::DialectInterface::Base { + : public pir::DialectInterface::Base { public: - explicit ParameterConvertInterface(ir::Dialect* dialect) : Base(dialect) {} + explicit ParameterConvertInterface(pir::Dialect* dialect) : Base(dialect) {} // NOTE(zhangbo): Only support new a CPU Variable. 
std::shared_ptr ParameterToVariable( - ir::Parameter* parameter); + pir::Parameter* parameter); - std::unique_ptr VariableToParameter( + std::unique_ptr VariableToParameter( paddle::framework::Variable* var); }; diff --git a/paddle/fluid/ir/dialect/paddle_dialect/utils/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/utils/CMakeLists.txt similarity index 64% rename from paddle/fluid/ir/dialect/paddle_dialect/utils/CMakeLists.txt rename to paddle/fluid/pir/dialect/operator/utils/CMakeLists.txt index 325f13f619b51..58eafb2cc3921 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/utils/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/operator/utils/CMakeLists.txt @@ -1,5 +1,5 @@ cc_library(op_yaml_info_parser SRCS op_yaml_info_parser.cc) cc_library( - pd_dialect_utils + pd_op_dialect_utils SRCS utils.cc - DEPS pd_dialect_core) + DEPS pd_op_dialect_core) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.cc b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc similarity index 98% rename from paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.cc rename to paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc index 8b5be8ff00cfd..eeb41ed3620ac 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.cc +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" namespace paddle { namespace dialect { diff --git a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h similarity index 97% rename from paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h rename to paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h index acbc1b8e19649..9557a3d5b7763 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" namespace paddle { namespace dialect { diff --git a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h similarity index 96% rename from paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h rename to paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h index 3df6ce5e22c15..462e88f4da327 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h @@ -14,9 +14,9 @@ #pragma once -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/builtin_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_type.h" namespace paddle { namespace dialect { diff --git a/paddle/fluid/ir/dialect/paddle_dialect/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc similarity index 60% rename from paddle/fluid/ir/dialect/paddle_dialect/utils/utils.cc rename to paddle/fluid/pir/dialect/operator/utils/utils.cc index 
e0ec875ca00d6..4681b9b100122 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -12,24 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" namespace paddle { namespace dialect { const std::unordered_set LegacyOpList = { - "pd.load_combine", - "pd.c_concat", - "pd.c_broadcast_", - "pd.fused_bn_add_activation_", - "pd.fused_bn_add_activation_grad", - "pd.c_sync_calc_stream_", - "pd.c_sync_comm_stream_", - "pd.send_v2", - "pd.recv_v2", - "pd.c_allreduce_sum", - "pd.c_allreduce_sum_"}; + "pd_op.load_combine", + "pd_op.c_concat", + "pd_op.c_broadcast_", + "pd_op.fused_bn_add_activation_", + "pd_op.fused_bn_add_activation_grad", + "pd_op.c_sync_calc_stream_", + "pd_op.c_sync_comm_stream_", + "pd_op.send_v2", + "pd_op.recv_v2", + "pd_op.c_allreduce_sum", + "pd_op.c_allreduce_sum_", + "pd_op.c_reduce_sum", + "pd_op.c_reduce_sum_", + "pd_op.c_allreduce_max_", + "pd_op.c_allgather", + "pd_op.seed"}; enum class AttrType { UNDEFINED = 0, @@ -53,20 +58,20 @@ enum class AttrType { NUM_ATTR_TYPES, }; -static inline AttrType GetAttributeType(const ir::Attribute& attr) { - if (attr.isa()) { +static inline AttrType GetAttributeType(const pir::Attribute& attr) { + if (attr.isa()) { return AttrType::BOOL; - } else if (attr.isa()) { + } else if (attr.isa()) { return AttrType::FLOAT; - } else if (attr.isa()) { + } else if (attr.isa()) { return AttrType::DOUBLE; - } else if (attr.isa()) { + } else if (attr.isa()) { return AttrType::INT32; - } else if (attr.isa()) { + } else if (attr.isa()) { return AttrType::INT64; - } else if (attr.isa()) { + } else if (attr.isa()) { return AttrType::ARRAY; - } else if (attr.isa()) { + } else if (attr.isa()) { return AttrType::STRING; } else if (attr.isa()) { return AttrType::INT_ARRAY; @@ -81,53 +86,54 @@ static inline AttrType GetAttributeType(const ir::Attribute& attr) { } } -static std::unordered_map> +static std::unordered_map< + AttrType, + std::function> kAttrCastMap = { {AttrType::BOOL, - [](const ir::Attribute& attr) { - return VariantType{attr.dyn_cast().data()}; + [](const pir::Attribute& attr) { + return VariantType{attr.dyn_cast().data()}; }}, {AttrType::FLOAT, - [](const ir::Attribute& attr) { - return VariantType{attr.dyn_cast().data()}; + [](const pir::Attribute& attr) { + return VariantType{attr.dyn_cast().data()}; }}, {AttrType::DOUBLE, - [](const ir::Attribute& attr) { - return VariantType{attr.dyn_cast().data()}; + [](const pir::Attribute& attr) { + return VariantType{attr.dyn_cast().data()}; }}, {AttrType::INT32, - [](const ir::Attribute& attr) { - return VariantType{attr.dyn_cast().data()}; + [](const pir::Attribute& attr) { + return VariantType{attr.dyn_cast().data()}; }}, {AttrType::INT64, - [](const ir::Attribute& attr) { - return VariantType{attr.dyn_cast().data()}; + [](const pir::Attribute& attr) { + return VariantType{attr.dyn_cast().data()}; }}, {AttrType::INT_ARRAY, - [](const ir::Attribute& attr) { + [](const pir::Attribute& attr) { return VariantType{ attr.dyn_cast() .data() .GetData()}; }}, {AttrType::STRING, - [](const ir::Attribute& attr) { - return VariantType{attr.dyn_cast().AsString()}; + [](const pir::Attribute& attr) { + return 
VariantType{attr.dyn_cast().AsString()}; }}, {AttrType::DATA_TYPE, - [](const ir::Attribute& attr) { + [](const pir::Attribute& attr) { return VariantType{ attr.dyn_cast().data()}; }}, {AttrType::PLACE, - [](const ir::Attribute& attr) { + [](const pir::Attribute& attr) { return VariantType{ attr.dyn_cast().data()}; }}, {AttrType::ARRAY, - [](const ir::Attribute& attr) { - auto attr_vec = attr.dyn_cast().AsVector(); + [](const pir::Attribute& attr) { + auto attr_vec = attr.dyn_cast().AsVector(); if (attr_vec.size() == 0) { return VariantType{std::vector()}; } @@ -137,37 +143,44 @@ static std::unordered_map vec_bools; for (auto vec_element : attr_vec) { vec_bools.push_back( - vec_element.dyn_cast().data()); + vec_element.dyn_cast().data()); } return VariantType{vec_bools}; } else if (element_type == AttrType::INT32) { std::vector vec_int32; for (auto vec_element : attr_vec) { vec_int32.push_back( - vec_element.dyn_cast().data()); + vec_element.dyn_cast().data()); } return VariantType{vec_int32}; } else if (element_type == AttrType::INT64) { std::vector vec_int64; for (auto vec_element : attr_vec) { vec_int64.push_back( - vec_element.dyn_cast().data()); + vec_element.dyn_cast().data()); } return VariantType{vec_int64}; } else if (element_type == AttrType::FLOAT) { std::vector vec_float; for (auto vec_element : attr_vec) { vec_float.push_back( - vec_element.dyn_cast().data()); + vec_element.dyn_cast().data()); } return VariantType{vec_float}; } else if (element_type == AttrType::DOUBLE) { std::vector vec_double; for (auto vec_element : attr_vec) { vec_double.push_back( - vec_element.dyn_cast().data()); + vec_element.dyn_cast().data()); } return VariantType{vec_double}; + } else if (element_type == AttrType::STRING) { + std::vector vec_string; + for (auto vec_element : attr_vec) { + vec_string.push_back( + vec_element.dyn_cast().AsString()); + } + return VariantType{vec_string}; } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported ir Attribute type when casting it into " @@ -176,7 +189,7 @@ static std::unordered_map()) { +static inline phi::DataType TransToPhiDataType(pir::Type dtype) { + if (dtype.isa()) { return phi::DataType::BFLOAT16; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::FLOAT16; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::FLOAT32; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::FLOAT64; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::UINT8; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::INT8; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::INT16; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::INT32; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::INT64; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::INT32; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::BOOL; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::COMPLEX64; - } else if (dtype.isa()) { + } else if (dtype.isa()) { return phi::DataType::COMPLEX128; } else { PADDLE_THROW(phi::errors::Unimplemented( @@ -66,36 +66,36 @@ static inline phi::DataType TransToPhiDataType(ir::Type dtype) { // use phi::DataType::INT32 for IndexType from builtin type to phi::DataType, // but only use INT32 not IndexType from phi::DataType type to builtin type. 
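Taken together, TransToPhiDataType (above) and TransToIrDataType (below) map between the pir builtin types and phi::DataType in both directions; the IndexType comment above marks the one asymmetry. A small sketch (not part of the diff, relying only on the two helper signatures shown here):

#include "paddle/fluid/pir/dialect/operator/utils/utils.h"

void DtypeRoundTrip() {
  pir::IrContext *ctx = pir::IrContext::Instance();
  // FLOAT32 -> pir::Float32Type -> FLOAT32.
  pir::Type t = paddle::dialect::TransToIrDataType(phi::DataType::FLOAT32, ctx);
  phi::DataType dt = paddle::dialect::TransToPhiDataType(t);
  // Caveat: pir::IndexType also maps to INT32, but the reverse direction
  // always yields pir::Int32Type, never IndexType.
  (void)dt;
}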
-static inline ir::Type TransToIrDataType(phi::DataType dtype, - ir::IrContext* ctx = nullptr) { +static inline pir::Type TransToIrDataType(phi::DataType dtype, + pir::IrContext* ctx = nullptr) { if (ctx == nullptr) { - ctx = ir::IrContext::Instance(); + ctx = pir::IrContext::Instance(); } switch (dtype) { case phi::DataType::BFLOAT16: - return ir::BFloat16Type::get(ctx); + return pir::BFloat16Type::get(ctx); case phi::DataType::FLOAT16: - return ir::Float16Type::get(ctx); + return pir::Float16Type::get(ctx); case phi::DataType::FLOAT32: - return ir::Float32Type::get(ctx); + return pir::Float32Type::get(ctx); case phi::DataType::FLOAT64: - return ir::Float64Type::get(ctx); + return pir::Float64Type::get(ctx); case phi::DataType::UINT8: - return ir::UInt8Type::get(ctx); + return pir::UInt8Type::get(ctx); case phi::DataType::INT8: - return ir::Int8Type::get(ctx); + return pir::Int8Type::get(ctx); case phi::DataType::INT16: - return ir::Int16Type::get(ctx); + return pir::Int16Type::get(ctx); case phi::DataType::INT32: - return ir::Int32Type::get(ctx); + return pir::Int32Type::get(ctx); case phi::DataType::INT64: - return ir::Int64Type::get(ctx); + return pir::Int64Type::get(ctx); case phi::DataType::BOOL: - return ir::BoolType::get(ctx); + return pir::BoolType::get(ctx); case phi::DataType::COMPLEX64: - return ir::Complex64Type::get(ctx); + return pir::Complex64Type::get(ctx); case phi::DataType::COMPLEX128: - return ir::Complex128Type::get(ctx); + return pir::Complex128Type::get(ctx); default: PADDLE_THROW(phi::errors::Unimplemented( "Unsupported phi data type `%s` when casting it into " @@ -104,22 +104,22 @@ static inline ir::Type TransToIrDataType(phi::DataType dtype, } } -static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar, - ir::IrContext* ctx = nullptr) { +static inline pir::Attribute TransToIrAttribute(phi::Scalar scalar, + pir::IrContext* ctx = nullptr) { if (ctx == nullptr) { - ctx = ir::IrContext::Instance(); + ctx = pir::IrContext::Instance(); } switch (scalar.dtype()) { case phi::DataType::FLOAT32: - return ir::FloatAttribute::get(ctx, scalar.to()); + return pir::FloatAttribute::get(ctx, scalar.to()); case phi::DataType::FLOAT64: - return ir::DoubleAttribute::get(ctx, scalar.to()); + return pir::DoubleAttribute::get(ctx, scalar.to()); case phi::DataType::INT32: - return ir::Int32Attribute::get(ctx, scalar.to()); + return pir::Int32Attribute::get(ctx, scalar.to()); case phi::DataType::INT64: - return ir::Int64Attribute::get(ctx, scalar.to()); + return pir::Int64Attribute::get(ctx, scalar.to()); case phi::DataType::BOOL: - return ir::BoolAttribute::get(ctx, scalar.to()); + return pir::BoolAttribute::get(ctx, scalar.to()); default: PADDLE_THROW(phi::errors::Unimplemented( "Unsupported phi data type `%s` when casting it into " @@ -166,7 +166,7 @@ inline DataType VarTypeToDataType( } } -VariantType GetAttributeData(const ir::Attribute& attr); +VariantType GetAttributeData(const pir::Attribute& attr); bool IsLegacyOp(const std::string& name); diff --git a/paddle/fluid/ir/phi_kernel_adaptor/CMakeLists.txt b/paddle/fluid/pir/phi_kernel_adaptor/CMakeLists.txt similarity index 56% rename from paddle/fluid/ir/phi_kernel_adaptor/CMakeLists.txt rename to paddle/fluid/pir/phi_kernel_adaptor/CMakeLists.txt index 1df1cc06db594..e1f8db179be6b 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/CMakeLists.txt +++ b/paddle/fluid/pir/phi_kernel_adaptor/CMakeLists.txt @@ -1,4 +1,4 @@ -# All source files of pd_dialect, except for the source file of op, which is generated in the compilation 
directory. +# All source files of pd_op_dialect, except for the source file of op, which is generated in the compilation directory. file(GLOB PHI_KERNEL_ADAPTOR_SRCS "*.cc") cc_library( diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_adaptor.h similarity index 62% rename from paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h rename to paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_adaptor.h index bb1b284ea1b6c..47c0d39856d2f 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h +++ b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_adaptor.h @@ -14,23 +14,23 @@ #pragma once -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/infermeta.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/builtin_dialect.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/utils.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/infermeta/binary.h" #include "paddle/phi/kernels/elementwise_add_kernel.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/utils.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" @@ -43,19 +43,19 @@ #include "paddle/fluid/platform/init.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "glog/logging.h" -#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" +#include "paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h" class PhiKernelAdaptor { public: explicit PhiKernelAdaptor(paddle::framework::Scope* scope) : scope_(scope) {} - void run_kernel_prog(ir::Program* program) { + void run_kernel_prog(pir::Program* program) { auto block = program->block(); - std::unordered_map value_2_var_name; + std::unordered_map value_2_var_name; std::unordered_map variable_2_var_name; std::map var_name_2_id; @@ -70,9 +70,9 @@ class PhiKernelAdaptor { &variable_2_var_name, &var_name_2_id, &variable_list); - ir::IrContext* ctx = ir::IrContext::Instance(); + pir::IrContext* ctx = pir::IrContext::Instance(); - 
ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); auto* dev_ctx = phi::DeviceContextPool::Instance().Get(phi::CPUPlace()); phi::Place cpu_place(phi::AllocationType::CPU); @@ -80,9 +80,9 @@ class PhiKernelAdaptor { auto attr_map = (*it)->attributes(); auto op_name = - attr_map.at("op_name").dyn_cast().AsString(); + attr_map.at("op_name").dyn_cast().AsString(); - ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op_name); + pir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op_name); auto impl = op1_info.GetInterfaceImpl(); @@ -96,7 +96,7 @@ class PhiKernelAdaptor { phi::InferMetaContext ctx; paddle::dialect::OpYamlInfoParser op_yaml_info_parser(yaml_info); - ir::BuildPhiContext< + pir::BuildPhiContext< phi::InferMetaContext, phi::MetaTensor, phi::MetaTensor, @@ -108,7 +108,7 @@ class PhiKernelAdaptor { infer_meta_impl->infer_meta_(&ctx); auto kernel_name = - attr_map.at("kernel_name").dyn_cast().AsString(); + attr_map.at("kernel_name").dyn_cast().AsString(); auto kernel_key = attr_map.at("kernel_key") .dyn_cast() .data(); @@ -118,17 +118,17 @@ class PhiKernelAdaptor { phi::KernelContext kernel_ctx(dev_ctx); - ir::BuildPhiContext, - paddle::small_vector, - true>((*it), - value_2_var_name, - scope_, - nullptr, - op_yaml_info_parser, - &kernel_ctx); + pir::BuildPhiContext, + paddle::small_vector, + true>((*it), + value_2_var_name, + scope_, + nullptr, + op_yaml_info_parser, + &kernel_ctx); kernel_fn(&kernel_ctx); auto out_value = (*it)->result(0); diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc similarity index 82% rename from paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc rename to paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc index c72641046f520..475e06f936f19 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc +++ b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc @@ -12,18 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. 
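run_kernel_prog above is the whole interpreter loop: build the scope tables, then for each op read the op_name/kernel_name/kernel_key attributes, run InferMeta through BuildPhiContext, and invoke the selected phi kernel. A minimal driver sketch (not part of the diff; it assumes kernel_program has already been lowered to the kernel dialect, since the adaptor reads kernel attributes from each op):

#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_adaptor.h"

void RunLoweredProgram(pir::Program *kernel_program) {
  paddle::framework::Scope scope;
  PhiKernelAdaptor adaptor(&scope);
  // Executes every op in the program's block on CPU, materializing
  // inputs and outputs as variables inside `scope`.
  adaptor.run_kernel_prog(kernel_program);
}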
-#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" - -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/utils.h" +#include "paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h" + +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/core/meta_tensor.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/utils.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" @@ -33,23 +33,23 @@ #include "paddle/fluid/framework/string_array.h" #include "paddle/fluid/framework/tensor_ref_array.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h" #include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" #include "paddle/phi/core/enforce.h" #include "glog/logging.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/operator.h" -namespace ir { +namespace pir { -void AddNewData(ir::Value value, +void AddNewData(pir::Value value, std::string name, paddle::framework::Variable* var, - std::unordered_map* value_2_var_name, + std::unordered_map* value_2_var_name, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, @@ -71,10 +71,10 @@ void AddNewData(ir::Value value, "The size of variable_list and var_name_2_id map should be equal")); } -void RenameData(ir::Value value, +void RenameData(pir::Value value, std::string new_name, std::string orig_name, - std::unordered_map* value_2_var_name, + std::unordered_map* value_2_var_name, std::unordered_map* variable_2_var_name, std::map* var_name_2_id) { @@ -104,11 +104,11 @@ using VariableNameMap = std::unordered_map; paddle::framework::Variable* CreateVar( - ir::Value value, + pir::Value value, paddle::framework::Scope* inner_scope, const std::string& var_name_prefix, bool force_persisable, - std::unordered_map* value_2_var_name, + std::unordered_map* value_2_var_name, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, @@ -142,9 +142,9 @@ paddle::framework::Variable* CreateVar( } void CheckInputVars( - ir::Operation* op, + pir::Operation* op, const std::string& op_name, - const std::unordered_map& value_2_var_name) { + const std::unordered_map& value_2_var_name) { size_t input_num = op->num_operands(); if (input_num > 0) { for 
(size_t i = 0; i < input_num; ++i) { @@ -162,10 +162,10 @@ void CheckInputVars( } } -void BuildValue(ir::Value value, +void BuildValue(pir::Value value, paddle::framework::Scope* inner_scope, const std::string& var_name_prefix, - std::unordered_map* value_2_var_name, + std::unordered_map* value_2_var_name, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, @@ -190,12 +190,12 @@ void BuildValue(ir::Value value, var->GetMutable(); } else if (value.type().isa()) { var->GetMutable(); - } else if (value.type().isa()) { + } else if (value.type().isa()) { auto tensor_array = var->GetMutable(); - for (size_t i = 0; i < value.type().dyn_cast().size(); + for (size_t i = 0; i < value.type().dyn_cast().size(); i++) { PADDLE_ENFORCE(value.type() - .dyn_cast()[i] + .dyn_cast()[i] .isa(), paddle::platform::errors::Fatal( "Element of VectorType output only support " @@ -219,10 +219,10 @@ void BuildValue(ir::Value value, } void HandleForSpecialOp( - ir::Operation* op, + pir::Operation* op, paddle::framework::Scope* inner_scope, const std::string& var_name_prefix, - std::unordered_map* value_2_var_name, + std::unordered_map* value_2_var_name, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, @@ -230,13 +230,13 @@ void HandleForSpecialOp( std::string op_name = op->name(); if (op->attributes().count("op_name")) { op_name = - op->attributes().at("op_name").dyn_cast().AsString(); + op->attributes().at("op_name").dyn_cast().AsString(); } - if (op_name == "pd.fetch") { + if (op_name == "pd_op.fetch") { // fetch is a very special op, with no output auto fetch_src_name = - op->attributes().at("name").dyn_cast().AsString(); + op->attributes().at("name").dyn_cast().AsString(); auto fetch_var_name = fetch_src_name + "@fetch"; auto* var = const_cast(inner_scope->root()) @@ -253,13 +253,13 @@ void HandleForSpecialOp( variable_list); } - if (op_name == "pd.feed" || op_name == "pd.data") { + if (op_name == "pd_op.feed" || op_name == "pd_op.data") { VLOG(6) << "Handle for" << op_name; auto value = op->result(0); VLOG(6) << "link feed output to feed in variable" << inner_scope; std::string name = - op->attributes().at("name").dyn_cast().AsString(); + op->attributes().at("name").dyn_cast().AsString(); paddle::framework::Variable* var = inner_scope->FindVar(name); PADDLE_ENFORCE(var, paddle::platform::errors::InvalidArgument( @@ -310,7 +310,7 @@ void HandleForSpecialOp( VLOG(6) << "Handle for builtin.set_parameter:"; auto param_name = op->attributes() .at("parameter_name") - .dyn_cast() + .dyn_cast() .AsString(); auto value = op->operand_source(0); @@ -338,10 +338,10 @@ void HandleForSpecialOp( var_name_2_id); } - if (op_name == "pd.shadow_output") { - VLOG(6) << "Handle for pd.shadow_ouptut"; + if (op_name == "pd_op.shadow_output") { + VLOG(6) << "Handle for pd_op.shadow_ouptut"; auto var_name = - op->attributes().at("name").dyn_cast().AsString(); + op->attributes().at("name").dyn_cast().AsString(); auto value = op->operand_source(0); // change opreand name to param_name @@ -363,7 +363,7 @@ void HandleForSpecialOp( VLOG(6) << "Handle for builtin.get_parameter:"; auto param_name = op->attributes() .at("parameter_name") - .dyn_cast() + .dyn_cast() .AsString(); auto value = op->result(0); @@ -387,7 +387,7 @@ void HandleForSpecialOp( "input of buildin slice not in name map")); int index = - op->attributes().at("index").dyn_cast().data(); + op->attributes().at("index").dyn_cast().data(); auto in_var = inner_scope->FindVar(value_2_var_name->at(in_value)); auto variable_array = in_var->Get(); @@ 
-428,36 +428,36 @@ void HandleForSpecialOp( } void HandleForInplaceOp( - ir::Operation* op, + pir::Operation* op, paddle::framework::Scope* inner_scope, const std::string& var_name_prefix, - std::unordered_map* value_2_var_name, + std::unordered_map* value_2_var_name, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, std::vector* variable_list) { if (op->num_results() < 1) return; - ir::IrContext* ctx = ir::IrContext::Instance(); + pir::IrContext* ctx = pir::IrContext::Instance(); std::string op_name = op->name(); if (op->attributes().count("op_name")) { op_name = - op->attributes().at("op_name").dyn_cast().AsString(); + op->attributes().at("op_name").dyn_cast().AsString(); } - ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); paddle::dialect::OpYamlInfoParser yaml_parser( op_info.GetInterfaceImpl() ->get_op_info_()); for (size_t i = 0; i < op->num_results(); ++i) { - ir::Value value = op->result(i); + pir::Value value = op->result(i); if (value.type().storage() == nullptr) { continue; } std::string value_name = yaml_parser.OutputNames()[i]; if (yaml_parser.HasInplace(value_name)) { const std::string& inplace_name = yaml_parser.InplaceName(value_name); - ir::Value inplace_value = + pir::Value inplace_value = op->operand_source(yaml_parser.InputName2Id().at(inplace_name)); std::string var_name = value_2_var_name->at(inplace_value); VLOG(4) << "inplace: " << value_name << " -> " << inplace_name @@ -465,7 +465,7 @@ void HandleForInplaceOp( value_2_var_name->emplace(value, var_name); } else if (yaml_parser.HasView(value_name)) { const std::string& view_name = yaml_parser.ViewName(value_name); - ir::Value view_value = + pir::Value view_value = op->operand_source(yaml_parser.InputName2Id().at(view_name)); const std::string& var_name = value_2_var_name->at(view_value); VLOG(4) << "view: " << value_name << " -> " << view_name @@ -485,10 +485,10 @@ void HandleForInplaceOp( // NOTE(zhiqiu): the persistable is created in inner_scope's root, and other is // created in inner_scope. 
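BuildScope, defined next, walks a block once and fills four side tables that the later BuildRuntimeContext and BuildPhiContext calls consume. A wiring sketch (not part of the diff; the variable-name prefix is arbitrary, and the container types follow the BuildScope declaration in phi_kernel_util.h further below):

#include "paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h"

void PrepareScope(const pir::Block &block, paddle::framework::Scope *scope) {
  std::unordered_map<pir::Value, std::string> value_2_var_name;
  std::unordered_map<const paddle::framework::Variable *, std::string>
      variable_2_var_name;
  std::map<std::string, int> var_name_2_id;
  std::vector<paddle::framework::Variable *> variable_list;
  // One pass over the block: creates (or renames) a variable per SSA value,
  // with special handling for feed/fetch, builtin.* and inplace ops.
  pir::BuildScope(block, scope, "inner_var_", &value_2_var_name,
                  &variable_2_var_name, &var_name_2_id, &variable_list);
}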
-void BuildScope(const ir::Block& block, +void BuildScope(const pir::Block& block, paddle::framework::Scope* inner_scope, const std::string& var_name_prefix, - std::unordered_map* value_2_var_name, + std::unordered_map* value_2_var_name, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, @@ -503,16 +503,16 @@ void BuildScope(const ir::Block& block, if (op->attributes().count("op_name")) { op_name = op->attributes() .at("op_name") - .dyn_cast() + .dyn_cast() .AsString(); } VLOG(4) << "build op:" << op_name; - if (op_name == "pd.feed" || op_name == "pd.fetch" || + if (op_name == "pd_op.feed" || op_name == "pd_op.fetch" || op_name == "builtin.combine" || op_name == "builtin.set_parameter" || op_name == "builtin.get_parameter" || op_name == "builtin.slice" || - op_name == "builtin.split" || op_name == "pd.data" || - op_name == "pd.shadow_output") { + op_name == "builtin.split" || op_name == "pd_op.data" || + op_name == "pd_op.shadow_output") { HandleForSpecialOp(op, inner_scope, var_name_prefix, @@ -529,7 +529,7 @@ void BuildScope(const ir::Block& block, if (op->attributes().count("is_inplace") != 0 && op->attributes() .at("is_inplace") - .dyn_cast() + .dyn_cast() .data()) { HandleForInplaceOp(op, inner_scope, @@ -559,8 +559,8 @@ void BuildScope(const ir::Block& block, } void BuildRuntimeContext( - ir::Operation* op, - const std::unordered_map& name_map, + pir::Operation* op, + const std::unordered_map& name_map, paddle::framework::Scope* scope, paddle::framework::Scope* local_scope, const paddle::dialect::OpYamlInfoParser& op_yaml_info, @@ -584,7 +584,7 @@ void BuildRuntimeContext( true, phi::errors::NotFound("param [%s] MUST in name2id map", name)); auto index = op_yaml_info.InputName2Id().at(name); - ir::Value ptr = op->operand_source(index); + pir::Value ptr = op->operand_source(index); auto in_var_name = name_map.at(ptr); VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name; @@ -602,7 +602,7 @@ void BuildRuntimeContext( auto& output_name_list = op_yaml_info.OutputNames(); for (size_t i = 0; i < output_name_list.size(); ++i) { auto name = output_name_list[i]; - ir::Value ptr = op->result(i); + pir::Value ptr = op->result(i); auto in_var_name = name_map.at(ptr); VLOG(6) << "ctx->EmplaceBackOutput: " << name << "\t" << in_var_name; @@ -618,7 +618,7 @@ void BuildRuntimeContext( if (type.isa() || type.isa()) { runtime_ctx->outputs[legacy_arg_name] = {var}; - } else if (type.isa()) { + } else if (type.isa()) { auto var_ref = var->Get(); std::vector vec_tmp; vec_tmp.reserve(var_ref.size()); @@ -629,14 +629,14 @@ void BuildRuntimeContext( } else { PADDLE_THROW(phi::errors::Unimplemented( "only support AllocatedDenseTensor, AllocatedSelectedRowsType and " - "ir::vector type")); + "pir::vector type")); } } } std::shared_ptr BuildOperatorBase( - ir::Operation* op, - const std::unordered_map& name_map, + pir::Operation* op, + const std::unordered_map& name_map, const paddle::dialect::OpYamlInfoParser& op_yaml_info, const std::unordered_map& variable_2_var_name, @@ -658,7 +658,7 @@ std::shared_ptr BuildOperatorBase( true, phi::errors::NotFound("param [%s] MUST in name2id map", name)); auto index = op_yaml_info.InputName2Id().at(name); - ir::Value ptr = op->operand_source(index); + pir::Value ptr = op->operand_source(index); auto in_var_name = name_map.at(ptr); @@ -672,52 +672,52 @@ std::shared_ptr BuildOperatorBase( for (auto& name : attr_name_list) { auto& val = op_attr_map.at(name); - if (val.isa()) { - attr_map[name] = val.dyn_cast().AsString(); - } else if (val.isa()) 
{ - attr_map[name] = val.dyn_cast().data(); - } else if (val.isa()) { - attr_map[name] = val.dyn_cast().data(); - } else if (val.isa()) { - attr_map[name] = val.dyn_cast().data(); - } else if (val.isa()) { - attr_map[name] = val.dyn_cast().data(); - } else if (val.isa()) { - attr_map[name] = val.dyn_cast().data(); - } else if (val.isa()) { - auto array_list = val.dyn_cast().AsVector(); + if (val.isa()) { + attr_map[name] = val.dyn_cast().AsString(); + } else if (val.isa()) { + attr_map[name] = val.dyn_cast().data(); + } else if (val.isa()) { + attr_map[name] = val.dyn_cast().data(); + } else if (val.isa()) { + attr_map[name] = val.dyn_cast().data(); + } else if (val.isa()) { + attr_map[name] = val.dyn_cast().data(); + } else if (val.isa()) { + attr_map[name] = val.dyn_cast().data(); + } else if (val.isa()) { + auto array_list = val.dyn_cast().AsVector(); PADDLE_ENFORCE( array_list.size() > 0, paddle::platform::errors::Fatal("Attribute %s is empty", name)); - if (array_list[0].isa()) { + if (array_list[0].isa()) { std::vector vec_int; for (auto attribute : array_list) { - vec_int.push_back(attribute.dyn_cast().data()); + vec_int.push_back(attribute.dyn_cast().data()); } attr_map[name] = vec_int; - } else if (array_list[0].isa()) { + } else if (array_list[0].isa()) { std::vector vec_int64; for (auto attribute : array_list) { - vec_int64.push_back(attribute.dyn_cast().data()); + vec_int64.push_back(attribute.dyn_cast().data()); } attr_map[name] = vec_int64; - } else if (array_list[0].isa()) { + } else if (array_list[0].isa()) { std::vector vec_bool; for (auto attribute : array_list) { - vec_bool.push_back(attribute.dyn_cast().data()); + vec_bool.push_back(attribute.dyn_cast().data()); } attr_map[name] = vec_bool; - } else if (array_list[0].isa()) { + } else if (array_list[0].isa()) { std::vector vec_float; for (auto attribute : array_list) { - vec_float.push_back(attribute.dyn_cast().data()); + vec_float.push_back(attribute.dyn_cast().data()); } attr_map[name] = vec_float; - } else if (array_list[0].isa()) { + } else if (array_list[0].isa()) { std::vector vec_double; for (auto attribute : array_list) { vec_double.push_back( - attribute.dyn_cast().data()); + attribute.dyn_cast().data()); } attr_map[name] = vec_double; } else { @@ -740,7 +740,7 @@ std::shared_ptr BuildOperatorBase( auto& output_name_list = op_yaml_info.OutputNames(); for (size_t i = 0; i < output_name_list.size(); ++i) { auto name = output_name_list[i]; - ir::Value ptr = op->result(i); + pir::Value ptr = op->result(i); auto out_var_name = name_map.at(ptr); @@ -749,7 +749,7 @@ std::shared_ptr BuildOperatorBase( if (type.isa() || type.isa()) { out_name_map[legacy_arg_name].push_back(out_var_name); - } else if (type.isa()) { + } else if (type.isa()) { auto var = scope->FindVar(out_var_name); auto var_ref = var->Get(); for (size_t k = 0; k < var_ref.size(); ++k) { @@ -761,7 +761,7 @@ std::shared_ptr BuildOperatorBase( } else { PADDLE_THROW(phi::errors::Unimplemented( "only support AllocatedDenseTensor, AllocatedSelectedRowsType and " - "ir::vector type")); + "pir::vector type")); } } @@ -773,4 +773,4 @@ std::shared_ptr BuildOperatorBase( return res; } -} // namespace ir +} // namespace pir diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h similarity index 74% rename from paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h rename to paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h index b1916d5418f77..037674467bc67 100644 --- 
a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h +++ b/paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h @@ -14,16 +14,16 @@ #pragma once -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/utils.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/core/meta_tensor.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/utils.h" #include "paddle/fluid/framework/new_executor/interpreter/execution_config.h" #include "paddle/fluid/framework/scope.h" @@ -33,36 +33,36 @@ #include "paddle/phi/core/kernel_context.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h" -#include "paddle/ir/core/type_name.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/pir/core/type_name.h" #include "glog/logging.h" -namespace ir { -void BuildScope(const ir::Block& block, +namespace pir { +void BuildScope(const pir::Block& block, paddle::framework::Scope* inner_scope, const std::string& var_name_prefix, - std::unordered_map* value_2_var_name, + std::unordered_map* value_2_var_name, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, std::vector* variable_list); void BuildRuntimeContext( - ir::Operation* op, - const std::unordered_map& name_map, + pir::Operation* op, + const std::unordered_map& name_map, paddle::framework::Scope* scope, paddle::framework::Scope* local_scope, const paddle::dialect::OpYamlInfoParser& op_yaml_info, paddle::framework::RuntimeContext* runtime_ctx); std::shared_ptr BuildOperatorBase( - ir::Operation* op, - const std::unordered_map& name_map, + pir::Operation* op, + const std::unordered_map& name_map, const paddle::dialect::OpYamlInfoParser& op_yaml_info, const std::unordered_map& variable_2_var_name, @@ -74,12 +74,13 @@ template -void BuildPhiContext(ir::Operation* op, - const std::unordered_map& name_map, - paddle::framework::Scope* scope, - paddle::framework::Scope* local_scope, - const paddle::dialect::OpYamlInfoParser& op_yaml_info, - Context* ctx) { +void BuildPhiContext( + pir::Operation* op, + const std::unordered_map& name_map, + paddle::framework::Scope* scope, + paddle::framework::Scope* local_scope, + const paddle::dialect::OpYamlInfoParser& op_yaml_info, + 
Context* ctx) { paddle::framework::Scope* inner_scope = local_scope != nullptr ? local_scope : scope; VLOG(6) << "Build " << get_type_name() << " in scope[" << scope @@ -96,7 +97,7 @@ void BuildPhiContext(ir::Operation* op, true, phi::errors::NotFound("param [%s] MUST in name2id map", t)); auto index = op_yaml_info.InputName2Id().at(t); - ir::Value ptr = op->operand_source(index); + pir::Value ptr = op->operand_source(index); if (!ptr) { phi::DenseTensor* ptr = nullptr; OutType in_ptr(ptr); @@ -142,7 +143,7 @@ void BuildPhiContext(ir::Operation* op, for (auto& t : vec_kernel_fn_attr_params) { if (name2id.count(t)) { // tensor attribute, get information from input - ir::Value ptr = op->operand_source(name2id.at(t)); + pir::Value ptr = op->operand_source(name2id.at(t)); auto in_var_name = name_map.at(ptr); @@ -153,7 +154,7 @@ void BuildPhiContext(ir::Operation* op, phi::Attribute attr = phi::TensorRef( &(inner_scope->FindVar(in_var_name)->Get())); ctx->EmplaceBackAttr(attr); - } else if (ptr.type().isa()) { + } else if (ptr.type().isa()) { auto& tensor_array = inner_scope->FindVar(in_var_name) ->Get(); if (tensor_array.size() == 1) { @@ -193,19 +194,20 @@ void BuildPhiContext(ir::Operation* op, } else if (attr_type_name == "paddle::dialect::DataTypeAttribute") { ctx->EmplaceBackAttr( attr_map[t].dyn_cast().data()); - } else if (attr_type_name == "ir::Int32Attribute") { - ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); - } else if (attr_type_name == "ir::Int64Attribute") { - ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); - } else if (attr_type_name == "ir::FloatAttribute") { - ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); - } else if (attr_type_name == "ir::BoolAttribute") { - ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); - } else if (attr_type_name == "ir::StrAttribute") { - ctx->EmplaceBackAttr(attr_map[t].dyn_cast().AsString()); + } else if (attr_type_name == "pir::Int32Attribute") { + ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); + } else if (attr_type_name == "pir::Int64Attribute") { + ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); + } else if (attr_type_name == "pir::FloatAttribute") { + ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); + } else if (attr_type_name == "pir::BoolAttribute") { + ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); + } else if (attr_type_name == "pir::StrAttribute") { + ctx->EmplaceBackAttr( + attr_map[t].dyn_cast().AsString()); } else if (attr_type_name == - "ir::ArrayAttribute") { - auto array_list = attr_map[t].dyn_cast().AsVector(); + "pir::ArrayAttribute") { + auto array_list = attr_map[t].dyn_cast().AsVector(); std::vector vec_res; if (array_list.size() > 0) { PADDLE_ENFORCE_EQ( @@ -220,29 +222,29 @@ void BuildPhiContext(ir::Operation* op, } } ctx->EmplaceBackAttr(vec_res); - } else if (attr_type_name == "ir::ArrayAttribute") { - auto array_list = attr_map[t].dyn_cast().AsVector(); + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr_map[t].dyn_cast().AsVector(); std::vector vec_res; if (array_list.size() > 0) { PADDLE_ENFORCE_EQ( - array_list[0].isa(), + array_list[0].isa(), true, phi::errors::Unimplemented( - "the 0th elementwise MUST be ir::Int32Attribute")); + "the 0th elementwise MUST be pir::Int32Attribute")); for (size_t i = 0; i < array_list.size(); ++i) { vec_res.push_back( - array_list[i].dyn_cast().data()); + array_list[i].dyn_cast().data()); } } ctx->EmplaceBackAttr(vec_res); - } else if (attr_type_name == "ir::ArrayAttribute") { - auto array_list = 
attr_map[t].dyn_cast().AsVector(); + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr_map[t].dyn_cast().AsVector(); std::vector vec_res; if (array_list.size() > 0) { - if (array_list[0].isa()) { + if (array_list[0].isa()) { for (size_t i = 0; i < array_list.size(); ++i) { vec_res.push_back( - array_list[i].dyn_cast().data()); + array_list[i].dyn_cast().data()); } } else { @@ -251,37 +253,37 @@ void BuildPhiContext(ir::Operation* op, } } ctx->EmplaceBackAttr(vec_res); - } else if (attr_type_name == "ir::ArrayAttribute") { - auto array_list = attr_map[t].dyn_cast().AsVector(); + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr_map[t].dyn_cast().AsVector(); std::vector vec_res; if (array_list.size() > 0) { PADDLE_ENFORCE_EQ( - array_list[0].isa(), + array_list[0].isa(), true, phi::errors::PreconditionNotMet( - "Element in array list MUST be ir::Int64Attribute ")); + "Element in array list MUST be pir::Int64Attribute ")); for (size_t i = 0; i < array_list.size(); ++i) { vec_res.push_back( - array_list[i].dyn_cast().data()); + array_list[i].dyn_cast().data()); } } ctx->EmplaceBackAttr(vec_res); - } else if (attr_type_name == "ir::ArrayAttribute") { - auto array_list = attr_map[t].dyn_cast().AsVector(); + } else if (attr_type_name == "pir::ArrayAttribute") { + auto array_list = attr_map[t].dyn_cast().AsVector(); std::vector vec_res; if (array_list.size() > 0) { PADDLE_ENFORCE_EQ( - array_list[0].isa(), + array_list[0].isa(), true, phi::errors::PreconditionNotMet( - "Element in array list MUST be ir::Int64Attribute ")); + "Element in array list MUST be pir::Int64Attribute ")); for (size_t i = 0; i < array_list.size(); ++i) { vec_res.push_back( - array_list[i].dyn_cast().data()); + array_list[i].dyn_cast().data()); } } ctx->EmplaceBackAttr(vec_res); @@ -300,7 +302,7 @@ void BuildPhiContext(ir::Operation* op, // TODO(phlrain): use var type instead of op name for (size_t i = 0; i < op->num_results(); ++i) { - ir::Value out_ptr = op->result(i); + pir::Value out_ptr = op->result(i); auto out_type = out_ptr.type(); if (out_type) { auto& name = name_map.at(out_ptr); @@ -320,7 +322,7 @@ void BuildPhiContext(ir::Operation* op, ctx->EmplaceBackOutput(OutType(const_cast( &(inner_scope->FindVar(name_map.at(out_ptr)) ->Get())))); - } else if (out_type.isa()) { + } else if (out_type.isa()) { OutListType outputs; auto& variable_array = inner_scope->FindVar(name_map.at(out_ptr)) ->Get(); @@ -348,4 +350,4 @@ void BuildPhiContext(ir::Operation* op, VLOG(6) << "Done build phi context"; } -} // namespace ir +} // namespace pir diff --git a/paddle/fluid/ir/transforms/CMakeLists.txt b/paddle/fluid/pir/transforms/CMakeLists.txt similarity index 71% rename from paddle/fluid/ir/transforms/CMakeLists.txt rename to paddle/fluid/pir/transforms/CMakeLists.txt index 36e06410d338a..ce2cb40f0eba4 100644 --- a/paddle/fluid/ir/transforms/CMakeLists.txt +++ b/paddle/fluid/pir/transforms/CMakeLists.txt @@ -1,12 +1,12 @@ cc_library( transform_general_functions SRCS transform_general_functions.cc - DEPS pd_dialect_core) + DEPS pd_op_dialect_core) cc_library( pd_op_to_kernel_pass SRCS pd_op_to_kernel_pass.cc - DEPS pd_kernel_dialect pd_dialect_core pd_dialect_utils) + DEPS pd_kernel_dialect pd_op_dialect_core pd_op_dialect_utils) cc_library( _constant_folding_pass @@ -16,4 +16,4 @@ cc_library( cc_library( pd_inplace_pass SRCS inplace_pass.cc - DEPS pd_dialect_core op_yaml_info_parser) + DEPS pd_op_dialect_core op_yaml_info_parser) diff --git 
a/paddle/fluid/ir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc similarity index 61% rename from paddle/fluid/ir/transforms/constant_folding_pass.cc rename to paddle/fluid/pir/transforms/constant_folding_pass.cc index 93699e3eae165..d3f78787841f0 100644 --- a/paddle/fluid/ir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -12,71 +12,74 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/transforms/constant_folding_pass.h" +#include "paddle/fluid/pir/transforms/constant_folding_pass.h" #include #include #include // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in -// paddle/fluid/ir/dialect/CMakeLists.txt. -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h" +// paddle/fluid/pir/dialect/CMakeLists.txt. +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" -#include "paddle/fluid/ir/transforms/transform_general_functions.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/parameter.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/pass/pass.h" -#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h" -#include "paddle/ir/pattern_rewrite/pattern_match.h" -#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/parameter.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { -class ConstantFoldingPattern : public ir::RewritePattern { +class ConstantFoldingPattern : public pir::RewritePattern { public: - ConstantFoldingPattern(ir::IrContext* context, - ir::PatternBenefit benefit = 1, + ConstantFoldingPattern(pir::IrContext* context, + pir::PatternBenefit benefit = 1, const std::vector& generated_names = {}) : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) { } - bool Match(ir::Operation* op) const override { + bool Match(pir::Operation* op) const override { // TODO(liuyuanle): Use trait to improve robustness. - if (op->dyn_cast() || - op->dyn_cast() || + if (op->dyn_cast() || + op->dyn_cast() || op->dyn_cast()) return false; // Inputs must come from get parameter op. 
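The Match/Rewrite pair above follows the usual two-phase rewrite-pattern design: Match only inspects the op, and the loop that follows rejects any op whose operands are not produced by a get_parameter op, while Rewrite performs the actual mutation. The dyn_cast targets in the guard lost their template arguments in extraction; they appear to be the parameter-handling ops and the fetch op, but that is an assumption. A minimal, framework-free sketch of the same match predicate, using hypothetical ToyOp/ToyValue stand-ins rather than the real pir API:

#include <string>
#include <vector>

// Hypothetical stand-ins for pir::Operation / pir::Value; only the fields
// the match predicate needs are modeled here.
struct ToyOp;
struct ToyValue {
  ToyOp* defining_op = nullptr;  // op that produced this value
};
struct ToyOp {
  std::string name;  // e.g. "builtin.get_parameter", "pd_op.add"
  std::vector<ToyValue*> operands;
};

// An op is a constant-folding candidate only if it is not itself a
// parameter/fetch op (the excluded names below are assumptions; the original
// template arguments were lost) and every operand is defined by
// builtin.get_parameter, i.e. all of its inputs are compile-time constants.
bool MatchForConstantFolding(const ToyOp& op) {
  if (op.name == "builtin.get_parameter" ||
      op.name == "builtin.set_parameter" || op.name == "pd_op.fetch") {
    return false;
  }
  for (const ToyValue* in : op.operands) {
    if (in == nullptr || in->defining_op == nullptr ||
        in->defining_op->name != "builtin.get_parameter") {
      return false;
    }
  }
  return true;
}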
for (uint32_t i = 0; i < op->num_operands(); ++i) - if (ir::GetDefiningOpForInput(op, i)->dyn_cast() == + if (pir::GetDefiningOpForInput(op, i)->dyn_cast() == nullptr) return false; return true; } - void Rewrite(ir::Operation* op, - ir::PatternRewriter& rewriter) const override { // NOLINT - ir::Program* program = op->GetParentProgram(); + void Rewrite(pir::Operation* op, + pir::PatternRewriter& rewriter) const override { // NOLINT + pir::Program* program = op->GetParentProgram(); auto temp_program = BuildProgramFromOperation(op); std::vector fetch_var_names; auto block = temp_program->block(); for (auto it = block->begin(); it != block->end(); ++it) { - if ((*it)->name() == "pd.fetch") { - size_t index = - (*it)->attributes().at("col").dyn_cast().data(); + if ((*it)->name() == "pd_op.fetch") { + size_t index = (*it) + ->attributes() + .at("col") + .dyn_cast() + .data(); if (fetch_var_names.size() < index + 1) { fetch_var_names.resize(index + 1); @@ -85,7 +88,7 @@ class ConstantFoldingPattern : public ir::RewritePattern { fetch_var_names[index] = (*it) ->attributes() .at("name") - .dyn_cast() + .dyn_cast() .AsString() + "@fetch"; } @@ -104,10 +107,11 @@ class ConstantFoldingPattern : public ir::RewritePattern { // TODO(liuyuanle): Support multiple output. auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]); - std::unique_ptr parameter = std::make_unique( - reinterpret_cast(out_tensor.data()), - out_tensor.numel() * phi::SizeOf(out_tensor.dtype()), - op->result(0).type()); + std::unique_ptr parameter = + std::make_unique( + reinterpret_cast(out_tensor.data()), + out_tensor.numel() * phi::SizeOf(out_tensor.dtype()), + op->result(0).type()); std::string param_name = "@constant_folding_pass@_" + std::to_string(suffix_++); @@ -119,20 +123,20 @@ class ConstantFoldingPattern : public ir::RewritePattern { program->SetParameter(param_name, std::move(parameter)); // rewriter.SetInsertionPoint(op); auto get_parameter_op = - rewriter.Build(param_name, op->result(0).type()); + rewriter.Build(param_name, op->result(0).type()); rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0)); rewriter.EraseOp(op); } private: - std::unique_ptr BuildProgramFromOperation( - ir::Operation* op) const { - auto program = std::make_unique(ir_context()); - ir::Builder builder = ir::Builder(ir_context(), program->block()); + std::unique_ptr BuildProgramFromOperation( + pir::Operation* op) const { + auto program = std::make_unique(ir_context()); + pir::Builder builder = pir::Builder(ir_context(), program->block()); // prepare op inputs - std::vector op_inputs; + std::vector op_inputs; for (uint32_t i = 0; i < op->num_operands(); i++) { PADDLE_ENFORCE_EQ( op->operand_source(i).type().isa(), @@ -141,22 +145,22 @@ class ConstantFoldingPattern : public ir::RewritePattern { "Op's input must be a dense tensor type.")); auto [param_name, param] = - ir::GetParameterFromValue(op->operand_source(i)); + pir::GetParameterFromValue(op->operand_source(i)); program->SetParameter(param_name, - std::make_unique(*param)); + std::make_unique(*param)); auto* param_var = scope_.FindVar(param_name); PADDLE_ENFORCE_NOT_NULL( param_var, phi::errors::InvalidArgument("Parameter var not in scope.")); - auto get_parameter_op = builder.Build( + auto get_parameter_op = builder.Build( param_name, op->operand_source(i).type()); op_inputs.push_back(get_parameter_op->result(0)); } // prepare op outputs - std::vector output_types; + std::vector output_types; for (uint32_t i = 0; i < op->num_results(); i++) { 
output_types.push_back(op->result(i).type()); } @@ -185,39 +189,39 @@ class ConstantFoldingPattern : public ir::RewritePattern { inline static paddle::framework::interpreter::ExecutionConfig exe_config_{}; }; -class ConstantFoldingPass : public ir::Pass { +class ConstantFoldingPass : public pir::Pass { public: // TODO(liuyuanle): Naming convention for pass. - ConstantFoldingPass() : ir::Pass("ConstantFoldingPass", 1) {} + ConstantFoldingPass() : pir::Pass("ConstantFoldingPass", 1) {} - bool Initialize(ir::IrContext* context) override { - ir::RewritePatternSet ps(context); + bool Initialize(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); ps.Add(context); - patterns_ = ir::FrozenRewritePatternSet(std::move(ps)); + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); return true; } - void Run(ir::Operation* op) override { - ir::GreedyRewriteConfig cfg; + void Run(pir::Operation* op) override { + pir::GreedyRewriteConfig cfg; cfg.use_top_down_traversal = true; cfg.max_iterations = 10; - ir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); } - bool CanApplyOn(ir::Operation* op) const override { + bool CanApplyOn(pir::Operation* op) const override { return op->name() == "builtin.module" && op->num_regions() > 0; } private: - ir::FrozenRewritePatternSet patterns_; + pir::FrozenRewritePatternSet patterns_; }; } // namespace -namespace ir { +namespace pir { std::unique_ptr CreateConstantFoldingPass() { return std::make_unique(); } -} // namespace ir +} // namespace pir diff --git a/paddle/fluid/ir/transforms/constant_folding_pass.h b/paddle/fluid/pir/transforms/constant_folding_pass.h similarity index 90% rename from paddle/fluid/ir/transforms/constant_folding_pass.h rename to paddle/fluid/pir/transforms/constant_folding_pass.h index 0c5ca794ad5bc..b49c9d90493b1 100644 --- a/paddle/fluid/ir/transforms/constant_folding_pass.h +++ b/paddle/fluid/pir/transforms/constant_folding_pass.h @@ -15,12 +15,12 @@ #pragma once #include -#include "paddle/ir/core/dll_decl.h" +#include "paddle/pir/core/dll_decl.h" -namespace ir { +namespace pir { class Pass; IR_API std::unique_ptr CreateConstantFoldingPass(); -} // namespace ir +} // namespace pir diff --git a/paddle/fluid/ir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc similarity index 70% rename from paddle/fluid/ir/transforms/inplace_pass.cc rename to paddle/fluid/pir/transforms/inplace_pass.cc index 222abc8344895..adfa5866799b9 100644 --- a/paddle/fluid/ir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/inplace_pass.cc @@ -12,25 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. 
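The inplace pass whose implementation follows walks a block and, for each kernel-dialect op, asks whether an output may reuse an input's buffer; its CanBeDeleted/CanDoInplace helpers encode three conditions: matching types, the input dying at this op (eagerly deletable), and the input not being a persistable parameter. A minimal sketch that folds those checks into one predicate, with a hypothetical ToyValue standing in for pir::Value:

#include <string>
#include <unordered_set>

// Hypothetical stand-in for pir::Value, modeling only what the decision uses.
struct ToyValue {
  std::string type;          // stand-in for pir::Type
  bool persistable = false;  // parameters must never be overwritten
};

// Output may reuse input's buffer only when the types match, the input is
// not persistable, and the input buffer is dead after this op (i.e. it is
// in the eager-deletion set).
bool CanDoInplaceSketch(const std::unordered_set<const ToyValue*>& eager_dels,
                        const ToyValue* input, const ToyValue* output) {
  if (input->type != output->type) return false;
  if (input->persistable) return false;
  return eager_dels.count(input) != 0;
}

Keeping the decision a pure predicate over precomputed liveness information, as the real pass does, is what lets the pass stay a single linear scan over the block.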
-#include "paddle/fluid/ir/transforms/inplace_pass.h" - -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/trait/inplace.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/pass/pass.h" -#include "paddle/ir/pass/pass_registry.h" +#include "paddle/fluid/pir/transforms/inplace_pass.h" + +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/trait/inplace.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" namespace details { // NOTE(zhangbo): Which kind of value can be deleted? // (1) Value's type needs to be AllocatedDenseTensorType or // AllocatedSelectedRowsType; (2) Value's is not persisable. -static bool CanBeDeleted(ir::Value value) { +static bool CanBeDeleted(pir::Value value) { if (!value.type()) { return false; } @@ -41,17 +41,17 @@ static bool CanBeDeleted(ir::Value value) { if (value.GetDefiningOp()->HasAttribute(kAttrIsPersisable)) { return !(value.GetDefiningOp() ->attribute(kAttrIsPersisable) - .dyn_cast<::ir::ArrayAttribute>() - .AsVector()[value.dyn_cast<::ir::OpResult>().GetResultIndex()] - .dyn_cast<::ir::BoolAttribute>() + .dyn_cast() + .AsVector()[value.dyn_cast().GetResultIndex()] + .dyn_cast() .data()); } return true; } -static bool CanDoInplace(const std::unordered_set& eager_dels, - ir::Value input, - ir::Value output) { +static bool CanDoInplace(const std::unordered_set& eager_dels, + pir::Value input, + pir::Value output) { if (input.type() != output.type()) { VLOG(9) << " -- input's type != output's type, can't do inplace"; return false; @@ -63,16 +63,17 @@ static bool CanDoInplace(const std::unordered_set& eager_dels, return true; } -static bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) { - if (op->dialect()->name().compare( - paddle::dialect::PaddleKernelDialect::name()) != 0) { +static bool IsNoNeedBuffer(pir::Operation* op, pir::Value value) { + if (op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) != + 0) { VLOG(8) << op->name() << "is not a kernel_dialect op, no need buffer is false"; return false; } auto op_name = - op->attributes().at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); - ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name); + op->attributes().at("op_name").dyn_cast().AsString(); + pir::OpInfo op_info = + pir::IrContext::Instance()->GetRegisteredOpInfo(op_name); if (op_info) { auto info_interface = op_info.GetInterfaceImpl(); @@ -90,27 +91,26 @@ static bool IsNoNeedBuffer(ir::Operation* op, ir::Value value) { return false; } -// NOTE(zhangbo): pd.feed's output and pd.fetch's input can not be eager +// 
NOTE(zhangbo): pd_op.feed's output and pd_op.fetch's input can not be eager // deleted. -static std::unordered_set GetSkipDeletionValues(ir::Block* block) { - std::unordered_set skip_dels; +static std::unordered_set GetSkipDeletionValues(pir::Block* block) { + std::unordered_set skip_dels; for (auto& op : *block) { - if (op->dialect()->name().compare( - paddle::dialect::PaddleKernelDialect::name()) != 0) { + if (op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) != + 0) { continue; } IR_ENFORCE(op->attributes().count("op_name") > 0, "kernel_dialect op should own an 'op_name' attribute."); - auto upper_op_name = op->attributes() - .at("op_name") - .dyn_cast<::ir::StrAttribute>() - .AsString(); + auto upper_op_name = + op->attributes().at("op_name").dyn_cast().AsString(); - if (upper_op_name == "pd.feed" || upper_op_name == "pd.data") { + if (upper_op_name == "pd_op.feed" || upper_op_name == "pd_op.data") { skip_dels.insert(op->result(0)); continue; } - if (upper_op_name == "pd.fetch" || upper_op_name == "pd.shadow_output") { + if (upper_op_name == "pd_op.fetch" || + upper_op_name == "pd_op.shadow_output") { skip_dels.insert(op->operand_source(0)); continue; } @@ -121,20 +121,20 @@ static std::unordered_set GetSkipDeletionValues(ir::Block* block) { // NOTE(zhangbo): For inplace Pass, currently only the kernel_dialect operator // is supported. Therefore, this function only returns the values in the // kernel_dialect operator that can be eager deleted. -static std::unordered_map> -GetEagerDeletionValues(ir::Block* block) { - std::unordered_set skip_dels = GetSkipDeletionValues(block); +static std::unordered_map> +GetEagerDeletionValues(pir::Block* block) { + std::unordered_set skip_dels = GetSkipDeletionValues(block); - std::unordered_map del_value_2_op; + std::unordered_map del_value_2_op; for (auto& op : *block) { std::string upper_op_name = op->name(); - if (op->dialect()->name().compare( - paddle::dialect::PaddleKernelDialect::name()) == 0) { + if (op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) == + 0) { IR_ENFORCE(op->attributes().count("op_name") > 0, "kernel_dialect op should own an 'op_name' attribute."); upper_op_name = op->attributes() .at("op_name") - .dyn_cast<::ir::StrAttribute>() + .dyn_cast() .AsString(); } @@ -154,14 +154,15 @@ GetEagerDeletionValues(ir::Block* block) { } for (size_t i = 0; i < op->num_results(); ++i) { - ir::Value output = op->result(i); + pir::Value output = op->result(i); if (output && CanBeDeleted(output)) { del_value_2_op[output] = op; } } } - std::unordered_map> eager_dels; + std::unordered_map> + eager_dels; for (auto& kv : del_value_2_op) { eager_dels[kv.second].insert(kv.first); } @@ -169,23 +170,23 @@ GetEagerDeletionValues(ir::Block* block) { return eager_dels; } -static std::unordered_map GetInplaceOps( - ir::Block* block) { +static std::unordered_map GetInplaceOps( + pir::Block* block) { const auto eager_dels = GetEagerDeletionValues(block); - std::unordered_map inplace_ops; + std::unordered_map inplace_ops; - std::unordered_set visited_values; - std::unordered_set reused_input_values; - std::unordered_set reused_output_values; + std::unordered_set visited_values; + std::unordered_set reused_input_values; + std::unordered_set reused_output_values; for (auto& op : *block) { for (size_t i = 0; i < op->num_operands(); ++i) { visited_values.insert(op->operand_source(i)); } - if (op->dialect()->name().compare( - paddle::dialect::PaddleKernelDialect::name()) != 0) { + if 
(op->dialect()->name().compare(paddle::dialect::KernelDialect::name()) != + 0) { VLOG(6) << op->name() << "is not a kernel_dialect op, inplace only support " "kernel_dialect operators"; @@ -197,13 +198,13 @@ static std::unordered_map GetInplaceOps( auto upper_op_attrs = op->attributes(); auto upper_op_name = - upper_op_attrs.at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); + upper_op_attrs.at("op_name").dyn_cast().AsString(); VLOG(6) << "analyse op: " << upper_op_name; // NOTE(zhangbo): add_grad cpu kernel can't do inplace, for the reason shown // in the function: CommonElementwiseBroadcastBackward // (paddle/phi/kernels/funcs/elementwise_grad_base.h) - if ((upper_op_name == "pd.add_grad") && + if ((upper_op_name == "pd_op.add_grad") && (upper_op_attrs.at("kernel_key") .dyn_cast() .data() @@ -215,7 +216,7 @@ static std::unordered_map GetInplaceOps( } if (upper_op_attrs.count("is_inplace") != 0 && - upper_op_attrs.at("is_inplace").dyn_cast().data()) { + upper_op_attrs.at("is_inplace").dyn_cast().data()) { VLOG(6) << upper_op_name << " is already an inplace op."; for (size_t i = 0; i < op->num_operands(); ++i) { reused_input_values.insert(op->operand_source(i)); @@ -227,8 +228,8 @@ static std::unordered_map GetInplaceOps( continue; } - ir::OpInfo upper_inplace_op_info = - ir::IrContext::Instance()->GetRegisteredOpInfo(upper_op_name + "_"); + pir::OpInfo upper_inplace_op_info = + pir::IrContext::Instance()->GetRegisteredOpInfo(upper_op_name + "_"); if (eager_dels.count(op) == 0 || (!upper_inplace_op_info)) { VLOG(6) << upper_op_name @@ -300,12 +301,12 @@ static std::unordered_map GetInplaceOps( } } // namespace details -class InplacePass : public ir::Pass { +class InplacePass : public pir::Pass { public: - InplacePass() : ir::Pass("InplacePass", 3) {} + InplacePass() : pir::Pass("InplacePass", 3) {} - void Run(ir::Operation* op) override { - auto module_op = op->dyn_cast(); + void Run(pir::Operation* op) override { + auto module_op = op->dyn_cast(); IR_ENFORCE(module_op, "InplacePass should run on module op."); auto* block = module_op.block(); @@ -315,9 +316,9 @@ class InplacePass : public ir::Pass { VLOG(6) << "Do inplace for: " << kv.first->attributes() .at("op_name") - .dyn_cast<::ir::StrAttribute>() + .dyn_cast() .AsString(); - ir::Block::iterator insert_pos = + pir::Block::iterator insert_pos = std::find(block->begin(), block->end(), kv.first); IR_ENFORCE(insert_pos != block->end(), "Operator %s not found in block.", @@ -325,26 +326,26 @@ class InplacePass : public ir::Pass { kv.first->set_attribute( "op_name", - ir::StrAttribute::get(ir::IrContext::Instance(), kv.second)); + pir::StrAttribute::get(pir::IrContext::Instance(), kv.second)); kv.first->set_attribute( "is_inplace", - ir::BoolAttribute::get(ir::IrContext::Instance(), true)); + pir::BoolAttribute::get(pir::IrContext::Instance(), true)); } LOG_FIRST_N(INFO, 1) << "Apply inplace pass on lowering ::ir::Program to Kernel Dialect."; } - bool CanApplyOn(ir::Operation* op) const override { + bool CanApplyOn(pir::Operation* op) const override { return op->name() == "builtin.module" && op->num_regions() > 0; } }; -namespace ir { +namespace pir { -std::unique_ptr CreateInplacePass() { +std::unique_ptr CreateInplacePass() { return std::make_unique(); } -} // namespace ir +} // namespace pir REGISTER_IR_PASS(inplace, InplacePass); diff --git a/paddle/fluid/ir/transforms/inplace_pass.h b/paddle/fluid/pir/transforms/inplace_pass.h similarity index 90% rename from paddle/fluid/ir/transforms/inplace_pass.h rename to 
paddle/fluid/pir/transforms/inplace_pass.h index 028d6a9eb94e8..c6d540243edc9 100644 --- a/paddle/fluid/ir/transforms/inplace_pass.h +++ b/paddle/fluid/pir/transforms/inplace_pass.h @@ -15,12 +15,12 @@ #pragma once #include -#include "paddle/ir/core/dll_decl.h" +#include "paddle/pir/core/dll_decl.h" -namespace ir { +namespace pir { class Pass; std::unique_ptr CreateInplacePass(); -} // namespace ir +} // namespace pir diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc similarity index 51% rename from paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc rename to paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 3555ebe354ab7..29b1df63a8562 100644 --- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -14,19 +14,19 @@ #include -#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h" - -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/trait/inplace.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_util.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_op.h" -#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" + +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/trait/inplace.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/fluid/platform/place.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" @@ -51,177 +51,23 @@ std::unordered_map Str2PhiDataType = { }; const std::unordered_set UnchangeOutputOps = { - "pd.data", + "pd_op.data", "builtin.combine", "builtin.slice", "builtin.split", - "pd.feed", - "pd.fetch", + "pd_op.feed", + "pd_op.fetch", "builtin.set_parameter", "builtin.get_parameter", - "pd.shadow_output"}; + "pd_op.shadow_output"}; -const std::unordered_set SpecialOpList = { - "builtin.combine", "builtin.slice", "builtin.split"}; - -ir::OpResult GetNewInput( - const ir::Value cur_in, - const std::unordered_map& map_value_pair, - const int index, - const std::string op_name) { - PADDLE_ENFORCE_EQ( - map_value_pair.count(cur_in), - true, - phi::errors::PreconditionNotMet( - "[%d]'s input of [%s] op MUST be in map pair", index, op_name)); - auto new_in = map_value_pair.at(cur_in); - return new_in; -} - -void DealWithSpecialBuiltinOps( - ir::Operation* op_item, - 
ir::Program* program, - std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair, - ir::IrContext* ctx) { - if (op_item->name() == "builtin.combine") { - std::vector out_places; - // Copy op inputs - std::vector vec_inputs; - std::vector vec_inner_types; - if (op_item->num_operands() > 0) { - for (size_t i = 0; i < op_item->num_operands(); ++i) { - auto cur_in = op_item->operand_source(i); - if (!cur_in) { - vec_inputs.emplace_back(); - continue; - } - auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); - vec_inputs.push_back(new_in); - vec_inner_types.push_back(new_in.type()); - if (new_in.type().isa()) { - out_places.push_back( - new_in.type() - .dyn_cast() - .place()); - } else if (new_in.type() - .isa()) { - out_places.push_back( - new_in.type() - .dyn_cast() - .place()); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "only support dense tensor type for now")); - } - } - } - // Copy op output type - std::vector op_output_types; - ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types); - op_output_types.push_back(t1); - - // Get op info - ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name()); - // Generate new op - ir::Operation* op = ir::Operation::Create( - vec_inputs, op_item->attributes(), op_output_types, op_info); - program->block()->push_back(op); - (*map_op_pair)[op_item] = op; - // only deal with single output - if (op_item->num_results() > 0) { - for (size_t i = 0; i < op_item->num_results(); ++i) { - (*map_value_pair)[op_item->result(i)] = op->result(i); - } - } - } - - if (op_item->name() == "builtin.slice") { - std::vector vec_inputs; - std::vector op_output_types; - if (op_item->num_operands() > 0) { - for (size_t i = 0; i < op_item->num_operands(); ++i) { - auto cur_in = op_item->operand_source(i); - if (!cur_in) { - vec_inputs.emplace_back(); - continue; - } - auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); - vec_inputs.push_back(new_in); - if (new_in.type().isa()) { - auto vec_types = new_in.type().dyn_cast().data(); - auto index = op_item->attributes() - .at("index") - .dyn_cast() - .data(); - op_output_types.push_back(vec_types[index]); - } else { - PADDLE_THROW( - phi::errors::Unimplemented("only support vector type for now")); - } - } - } - - // Get op info - ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name()); - // Generate new op - ir::Operation* op = ir::Operation::Create( - vec_inputs, op_item->attributes(), op_output_types, op_info); - program->block()->push_back(op); - (*map_op_pair)[op_item] = op; - // only deal with single output - if (op_item->num_results() > 0) { - for (size_t i = 0; i < op_item->num_results(); ++i) { - (*map_value_pair)[op_item->result(i)] = op->result(i); - } - } - } - - if (op_item->name() == "builtin.split") { - std::vector out_places(op_item->num_results()); - // Copy op inputs - std::vector vec_inputs; - std::vector op_output_types; - if (op_item->num_operands() > 0) { - for (size_t i = 0; i < op_item->num_operands(); ++i) { - auto cur_in = op_item->operand_source(i); - if (!cur_in) { - vec_inputs.emplace_back(); - continue; - } - auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); - vec_inputs.push_back(new_in); - - if (new_in.type().isa()) { - auto vec_types = new_in.type().dyn_cast().data(); - for (uint64_t idx = 0; idx < vec_types.size(); idx++) { - op_output_types.push_back(vec_types[idx]); - } - } else { - PADDLE_THROW( - phi::errors::Unimplemented("only support vector type for now")); - } - } - } - - // Get op info - 
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name()); - // Generate new op - ir::Operation* op = ir::Operation::Create( - vec_inputs, op_item->attributes(), op_output_types, op_info); - program->block()->push_back(op); - (*map_op_pair)[op_item] = op; - // only deal with single output - if (op_item->num_results() > 0) { - for (size_t i = 0; i < op_item->num_results(); ++i) { - (*map_value_pair)[op_item->result(i)] = op->result(i); - } - } - } - VLOG(6) << "Deep copy a new builtin op: " << op_item->name(); -} +const std::unordered_set SpecialLowerOps = { + "builtin.combine", + "builtin.slice", + "builtin.split", +}; -bool NeedFallBackCpu(const ir::Operation* op, +bool NeedFallBackCpu(const pir::Operation* op, const std::string& kernel_fn_name, const phi::KernelKey& kernel_key) { if (UnchangeOutputOps.count(op->name())) { @@ -255,7 +101,7 @@ bool NeedFallBackCpu(const ir::Operation* op, phi::Backend GetDstBackend(const std::string& op_name, phi::Place place, - OpYamlInfoParser* op_yaml_info_parser, + const OpYamlInfoParser* op_yaml_info_parser, phi::Backend kernel_def_backend, size_t input_index) { if (op_name == "builtin.set_parameter" && @@ -275,14 +121,16 @@ phi::Backend GetDstBackend(const std::string& op_name, return dst_backend; } -bool NeedFallBackFromGPUDNN2GPU(ir::Operation* op, +bool NeedFallBackFromGPUDNN2GPU(pir::Operation* op, const phi::KernelKey kernel_key) { // NOTE(phlrain): keep the same kernel select strategy with // GetExepectKernelKey - if (op->name() == "pd.pool2d" || op->name() == "pd.pool2d_grad") { + if (op->name() == "pd_op.pool2d" || op->name() == "pd_op.pool2d_grad") { if (kernel_key.backend() == phi::Backend::GPUDNN && - (op->attributes().at("adaptive").dyn_cast().data() == - true)) { + (op->attributes() + .at("adaptive") + .dyn_cast() + .data() == true)) { return true; } } @@ -290,26 +138,26 @@ bool NeedFallBackFromGPUDNN2GPU(ir::Operation* op, return false; } -std::set GetSkipFeedNames(ir::Block* block) { +std::set GetSkipFeedNames(pir::Block* block) { std::set data_op_names; for (auto op_item : *block) { - if (op_item->name() == "pd.data") { + if (op_item->name() == "pd_op.data") { data_op_names.insert(op_item->attributes() .at("name") - .dyn_cast() + .dyn_cast() .AsString()); } } return data_op_names; } -bool SkipFeedOp(ir::Operation* op, const std::set& feed_names) { +bool SkipFeedOp(pir::Operation* op, const std::set& feed_names) { return feed_names.count( - op->attributes().at("name").dyn_cast().AsString()); + op->attributes().at("name").dyn_cast().AsString()); } std::vector> GetFakeTensorList( - ir::Value new_input_tmp) { + pir::Value new_input_tmp) { std::vector> vec_res; auto input_type = new_input_tmp.type(); @@ -356,8 +204,8 @@ std::vector> GetFakeTensorList( } else if (input_type.isa()) { vec_res.push_back(build_fake_selected_rows( input_type.dyn_cast())); - } else if (input_type.isa()) { - auto vec_inner_types = input_type.dyn_cast().data(); + } else if (input_type.isa()) { + auto vec_inner_types = input_type.dyn_cast().data(); for (size_t i = 0; i < vec_inner_types.size(); ++i) { if (vec_inner_types[i].isa()) { vec_res.push_back(build_fake_dense_tensor( @@ -372,29 +220,29 @@ std::vector> GetFakeTensorList( return vec_res; } -ir::OpResult AddPlaceTransferOp(ir::OpResult in, - ir::Type out_type, - const phi::Place& src_place, - const phi::Place& dst_place, - const phi::KernelKey& kernel_key, - ir::Program* program) { - ir::IrContext* ctx = ir::IrContext::Instance(); +pir::OpResult AddPlaceTransferOp(pir::OpResult in, + pir::Type out_type, + 
const phi::Place& src_place, + const phi::Place& dst_place, + const phi::KernelKey& kernel_key, + pir::Program* program) { + pir::IrContext* ctx = pir::IrContext::Instance(); std::string op_name = paddle::dialect::PhiKernelOp::name(); - ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); if ((src_place.GetType() == phi::AllocationType::CPU) && (dst_place.GetType() == phi::AllocationType::GPU)) { auto copy_kernel_key = kernel_key; copy_kernel_key.set_backend(phi::Backend::GPU); - std::unordered_map op_attribute{ - {"op_name", ir::StrAttribute::get(ctx, "pd.memcpy_h2d")}, - {"kernel_name", ir::StrAttribute::get(ctx, "memcpy_h2d")}, + std::unordered_map op_attribute{ + {"op_name", pir::StrAttribute::get(ctx, "pd_op.memcpy_h2d")}, + {"kernel_name", pir::StrAttribute::get(ctx, "memcpy_h2d")}, {"kernel_key", dialect::KernelAttribute::get(ctx, copy_kernel_key)}, - {"dst_place_type", ir::Int32Attribute::get(ctx, 1)}}; + {"dst_place_type", pir::Int32Attribute::get(ctx, 1)}}; - ir::Operation* op = - ir::Operation::Create({in}, op_attribute, {out_type}, op_info); + pir::Operation* op = + pir::Operation::Create({in}, op_attribute, {out_type}, op_info); if (in.GetDefiningOp()->HasAttribute(kAttrIsPersisable)) { op->set_attribute(kAttrIsPersisable, @@ -409,14 +257,14 @@ ir::OpResult AddPlaceTransferOp(ir::OpResult in, (dst_place.GetType() == phi::AllocationType::CPU)) { auto copy_kernel_key = kernel_key; copy_kernel_key.set_backend(phi::Backend::GPU); - std::unordered_map op_attribute{ - {"op_name", ir::StrAttribute::get(ctx, "pd.memcpy_d2h")}, - {"kernel_name", ir::StrAttribute::get(ctx, "memcpy_d2h")}, + std::unordered_map op_attribute{ + {"op_name", pir::StrAttribute::get(ctx, "pd_op.memcpy_d2h")}, + {"kernel_name", pir::StrAttribute::get(ctx, "memcpy_d2h")}, {"kernel_key", dialect::KernelAttribute::get(ctx, copy_kernel_key)}, - {"dst_place_type", ir::Int32Attribute::get(ctx, 0)}}; + {"dst_place_type", pir::Int32Attribute::get(ctx, 0)}}; - ir::Operation* op = - ir::Operation::Create({in}, op_attribute, {out_type}, op_info); + pir::Operation* op = + pir::Operation::Create({in}, op_attribute, {out_type}, op_info); program->block()->push_back(op); @@ -428,10 +276,10 @@ ir::OpResult AddPlaceTransferOp(ir::OpResult in, } } -ir::Type BuildOutputType(ir::Type type, - const phi::Place& place, - phi::DataType data_type, - ir::IrContext* ctx) { +pir::Type BuildOutputType(pir::Type type, + const phi::Place& place, + phi::DataType data_type, + pir::IrContext* ctx) { if (type.isa()) { auto dense_tensor_type = type.dyn_cast(); auto out_dtype = dense_tensor_type.dtype(); @@ -473,8 +321,8 @@ ir::Type BuildOutputType(ir::Type type, } phi::DataType GetKernelDataTypeByYamlInfo( - const ir::Operation* op, - const std::unordered_map& map_value_pair, + const pir::Operation* op, + const std::unordered_map& map_value_pair, const dialect::OpYamlInfoParser* op_info_parser) { auto& attr_map = op->attributes(); auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype; @@ -495,8 +343,8 @@ phi::DataType GetKernelDataTypeByYamlInfo( if (type.isa()) { kernel_data_type = TransToPhiDataType( type.dyn_cast().dtype()); - } else if (type.isa()) { - auto vec_data = type.dyn_cast().data(); + } else if (type.isa()) { + auto vec_data = type.dyn_cast().data(); if (vec_data.empty()) { kernel_data_type = phi::DataType::UNDEFINED; } else { @@ -547,8 +395,8 @@ phi::DataType GetKernelDataTypeByYamlInfo( } phi::Backend GetKernelBackendByYamlInfo( - const ir::Operation* 
op, - const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair, + const pir::Operation* op, + const std::unordered_map<pir::Value, pir::OpResult>& map_value_pair, const dialect::OpYamlInfoParser* op_info_parser) { auto& attr_map = op->attributes(); auto& backend_info = op_info_parser->OpRuntimeInfo().kernel_key_backend; @@ -565,8 +413,8 @@ phi::Backend GetKernelBackendByYamlInfo( if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) { kernel_backend = paddle::experimental::ParseBackend( type.dyn_cast<paddle::dialect::AllocatedDenseTensorType>().place()); - } else if (type.isa<ir::VectorType>()) { - auto vec_data = type.dyn_cast<ir::VectorType>().data(); + } else if (type.isa<pir::VectorType>()) { + auto vec_data = type.dyn_cast<pir::VectorType>().data(); if (vec_data.empty()) { kernel_backend = phi::Backend::UNDEFINED; } else { @@ -617,11 +465,12 @@ phi::Backend GetKernelBackendByYamlInfo( } phi::KernelKey GetKernelKey( - ir::Operation* op, + pir::Operation* op, const phi::Place& place, - const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair, + const std::string& kernel_fn_str, + const std::unordered_map<pir::Value, pir::OpResult>& map_value_pair, dialect::OpYamlInfoParser* op_info_parser = nullptr) { - if (op->name() == "pd.feed") { + if (op->name() == "pd_op.feed") { // NOTE, for now the feed op doesn't need a kernel, so the data type comes // from the Op Result; the next op uses the base program's data type return {phi::Backend::CPU, phi::DataLayout::ANY, TransToPhiDataType( op->result(0).type().dyn_cast<paddle::dialect::DenseTensorType>().dtype())}; } - if (op->name() == "pd.data") { + if (op->name() == "pd_op.data") { // NOTE, for now the feed op doesn't need a kernel, so the data type comes // from the Op Result; the next op uses the base program's data type auto data_place = @@ -644,6 +493,14 @@ phi::KernelKey GetKernelKey( op->result(0).type().dyn_cast<paddle::dialect::DenseTensorType>().dtype())}; } + if (op->name() == "pd_op.seed") { + auto backend = paddle::experimental::ParseBackend(place); + return {backend, + phi::DataLayout::ANY, + TransToPhiDataType( + op->result(0).type().dyn_cast<paddle::dialect::DenseTensorType>().dtype())}; + } + phi::Backend kernel_backend = phi::Backend::UNDEFINED; phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED; phi::DataType kernel_data_type = phi::DataType::UNDEFINED; @@ -659,14 +516,14 @@ phi::KernelKey GetKernelKey( GetKernelBackendByYamlInfo(op, map_value_pair, op_info_parser); // parse all the input tensors - if (tensor_input_number == 0 || op->name() == "pd.full_") { + if (tensor_input_number == 0 || op->name() == "pd_op.full_") { // all the information has to come from attributes and context - if (op->name() == "pd.uniform") { + if (op->name() == "pd_op.uniform") { // try to process uniform, use shape to determine backend // TODO(phlrain): should support other initialize ops auto define_op = op->operand_source(0).GetDefiningOp(); - if (define_op->name() == "pd.full_int_array") { + if (define_op->name() == "pd_op.full_int_array") { auto shape = define_op->attributes() .at("value") .dyn_cast<paddle::dialect::IntArrayAttribute>() @@ -714,7 +571,7 @@ phi::KernelKey GetKernelKey( // don't know how to select the kernel for the next op that // uses the data op's output as input. So, we need to set the kernel backend // manually.
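GetKernelKey resolves a (backend, layout, dtype) triple with a clear precedence: hard-coded per-op rules first (the feed/data/seed cases above), then backend and dtype declared in the op's YAML info, then values derived from the input tensors, and finally the fall-backs applied at the end of the function (load_combine forcing FLOAT32, the CPU and GPUDNN-to-GPU fall-backs). A compact sketch of that precedence, with plain enums in place of the phi types; all names here are illustrative, not the real API:

#include <optional>
#include <string>

enum class Backend { UNDEFINED, CPU, GPU, GPUDNN };
enum class DataType { UNDEFINED, FLOAT32, FLOAT64 };

struct KernelKeySketch {
  Backend backend;
  DataType dtype;
};

// Mirrors the resolution order in GetKernelKey: YAML-declared info wins,
// input-derived info fills the gaps, and per-op fall-backs run last.
KernelKeySketch ResolveKernelKey(const std::string& op_name,
                                 std::optional<Backend> yaml_backend,
                                 std::optional<DataType> yaml_dtype,
                                 std::optional<DataType> first_input_dtype,
                                 Backend place_backend) {
  KernelKeySketch key{Backend::UNDEFINED, DataType::UNDEFINED};
  if (yaml_backend) key.backend = *yaml_backend;  // from op YAML info
  if (yaml_dtype) key.dtype = *yaml_dtype;
  if (key.dtype == DataType::UNDEFINED && first_input_dtype)
    key.dtype = *first_input_dtype;  // derived from input tensors
  if (key.backend == Backend::UNDEFINED)
    key.backend = place_backend;  // fall back to the target place
  if (op_name == "pd_op.load_combine") key.dtype = DataType::FLOAT32;
  // NeedFallBackCpu / NeedFallBackFromGPUDNN2GPU-style adjustments go here.
  return key;
}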
- if (op->operand_source(i).GetDefiningOp()->name() == "pd.data") { + if (op->operand_source(i).GetDefiningOp()->name() == "pd_op.data") { auto data_op = op->operand_source(i).GetDefiningOp(); auto data_place = data_op->attributes() .at("place") @@ -733,7 +590,7 @@ phi::KernelKey GetKernelKey( auto combine_op = op->operand_source(i).GetDefiningOp(); for (size_t j = 0; j < combine_op->num_operands(); ++j) { if (combine_op->operand_source(j).GetDefiningOp()->name() == - "pd.data") { + "pd_op.data") { auto data_op = combine_op->operand_source(j).GetDefiningOp(); auto data_place = data_op->attributes() .at("place") @@ -774,409 +631,587 @@ phi::KernelKey GetKernelKey( } phi::KernelKey res(kernel_backend, kernel_layout, kernel_data_type); - return res; -} -std::unique_ptr PdOpLowerToKernelPass(ir::Program* prog, - phi::Place place) { - if (VLOG_IS_ON(2)) { - std::stringstream ss; - prog->Print(ss); - VLOG(2) << "Program after lowering to kernel pass : " << ss.str(); + if (op->name() == "pd_op.load_combine") { + res.set_dtype(phi::DataType::FLOAT32); + } + if (NeedFallBackCpu((op), kernel_fn_str, res)) { + res.set_backend(phi::Backend::CPU); } - auto program = std::make_unique(ir::IrContext::Instance()); - - auto block = prog->block(); - - ir::IrContext* ctx = ir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - std::unordered_map map_op_pair; - std::unordered_map map_value_pair; - - std::string phi_kernel_op_name = paddle::dialect::PhiKernelOp::name(); - ir::OpInfo phi_kernel_op_info = ctx->GetRegisteredOpInfo(phi_kernel_op_name); - std::string legacy_kernel_op_name = paddle::dialect::LegacyKernelOp::name(); - ir::OpInfo legacy_kernel_op_info = - ctx->GetRegisteredOpInfo(legacy_kernel_op_name); + if (NeedFallBackFromGPUDNN2GPU(op, res)) { + res.set_backend(phi::Backend::GPU); + } - auto skip_feed_names = GetSkipFeedNames(block); + return res; +} - for (auto op_item : *block) { - VLOG(6) << "op name " << op_item->name(); - if ((op_item->name() == "pd.feed") && - SkipFeedOp(op_item, skip_feed_names)) { - continue; - } +pir::OpResult GetNewInput( + const pir::Value cur_in, + const std::unordered_map& map_value_pair, + const int index, + const std::string op_name) { + PADDLE_ENFORCE_EQ( + map_value_pair.count(cur_in), + true, + phi::errors::PreconditionNotMet( + "[%d]'s input of [%s] op MUST be in map pair", index, op_name)); + auto new_in = map_value_pair.at(cur_in); + return new_in; +} - if (SpecialOpList.count(op_item->name())) { - DealWithSpecialBuiltinOps( - op_item, program.get(), &map_op_pair, &map_value_pair, ctx); - continue; +void HandleForSpecialOp( + pir::Operation* op_item, + pir::Program* program, + pir::IrContext* ctx, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair) { + std::vector vec_inputs; + std::vector op_output_types; + if (op_item->name() == "builtin.combine") { + // Copy op inputs + std::vector vec_inner_types; + if (op_item->num_operands() > 0) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); + vec_inputs.push_back(new_in); + vec_inner_types.push_back(new_in.type()); + } } + // Copy op output type - // Lower from PaddleDialect to KernelDialect - paddle::dialect::OpYamlInfoInterface op_info_interface = - op_item->dyn_cast(); - - std::unique_ptr op_info_parser(nullptr); - if (op_info_interface) { - op_info_parser = - 
std::make_unique(op_info_interface.GetOpInfo()); - } + pir::Type t1 = pir::VectorType::get(ctx, vec_inner_types); + op_output_types.push_back(t1); + } - std::string kernel_fn_str; - if (op_info_parser != nullptr) { - kernel_fn_str = op_info_parser->OpRuntimeInfo().kernel_func[0]; - } + if (op_item->name() == "builtin.slice") { + if (op_item->num_operands() > 0) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); + vec_inputs.push_back(new_in); - if (op_item->name() == "pd.add_n_" || - op_item->name() == "pd.add_n_with_kernel") { - if (op_item->result(0).type().isa()) { - kernel_fn_str = "add_n_sr"; + if (new_in.type().isa()) { + auto vec_types = new_in.type().dyn_cast().data(); + auto index = op_item->attributes() + .at("index") + .dyn_cast() + .data(); + op_output_types.push_back(vec_types[index]); + } else { + PADDLE_THROW( + phi::errors::Unimplemented("only support vector type for now")); + } } } + } - auto kernel_key = - GetKernelKey(op_item, place, map_value_pair, op_info_parser.get()); - VLOG(6) << "kernel type " << kernel_key; + if (op_item->name() == "builtin.split") { + if (op_item->num_operands() > 0) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); + vec_inputs.push_back(new_in); - if (op_item->name() == "pd.load_combine") { - kernel_key.set_dtype(phi::DataType::FLOAT32); - } - if (NeedFallBackCpu((op_item), kernel_fn_str, kernel_key)) { - kernel_key.set_backend(phi::Backend::CPU); + if (new_in.type().isa()) { + auto vec_types = new_in.type().dyn_cast().data(); + for (uint64_t idx = 0; idx < vec_types.size(); idx++) { + op_output_types.push_back(vec_types[idx]); + } + } else { + PADDLE_THROW( + phi::errors::Unimplemented("only support vector type for now")); + } + } } + } - if (NeedFallBackFromGPUDNN2GPU(op_item, kernel_key)) { - kernel_key.set_backend(phi::Backend::GPU); + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name()); + // Generate new op + pir::Operation* op = pir::Operation::Create( + vec_inputs, op_item->attributes(), op_output_types, op_info); + program->block()->push_back(op); + (*map_op_pair)[op_item] = op; + // only deal with single output + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + (*map_value_pair)[op_item->result(i)] = op->result(i); } + } + VLOG(6) << "Deep copy a new builtin op: " << op_item->name(); +} - // only for single output - // need update new kernel key layout and data tyep - - std::vector op_output_types; - if (op_item->num_results() > 0) { - auto phi_kernel = phi::KernelFactory::Instance().SelectKernelWithGPUDNN( - kernel_fn_str, kernel_key); - auto args_def = phi_kernel.args_def(); - auto output_defs = args_def.output_defs(); - if (!UnchangeOutputOps.count(op_item->name()) && - !IsLegacyOp(op_item->name())) { - PADDLE_ENFORCE_EQ( - op_item->num_results(), - output_defs.size(), - phi::errors::PreconditionNotMet( - "op [%s] kernel output args defs should equal op outputs", - op_item->name())); - } +std::vector BuildOpOutputType(pir::Operation* op_item, + const std::string& kernel_fn_str, + const phi::KernelKey& kernel_key, + pir::IrContext* ctx) { + if (op_item->num_results() == 0) { + return {}; + } + std::vector 
op_output_types; + auto phi_kernel = phi::KernelFactory::Instance().SelectKernelWithGPUDNN( + kernel_fn_str, kernel_key); + auto args_def = phi_kernel.args_def(); + auto output_defs = args_def.output_defs(); + if (!UnchangeOutputOps.count(op_item->name()) && + !IsLegacyOp(op_item->name())) { + PADDLE_ENFORCE_EQ( + op_item->num_results(), + output_defs.size(), + phi::errors::PreconditionNotMet( + "op [%s] kernel output args defs should equal op outputs", + op_item->name())); + } - for (size_t i = 0; i < op_item->num_results(); ++i) { - phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend()); + for (size_t i = 0; i < op_item->num_results(); ++i) { + phi::Place out_place = phi::TransToPhiPlace(kernel_key.backend()); - phi::DataType out_phi_dtype = phi::DataType::UNDEFINED; - if ((!UnchangeOutputOps.count(op_item->name())) && - (!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) { - out_place = phi::TransToPhiPlace(output_defs[i].backend); - out_phi_dtype = output_defs[i].dtype; - } + phi::DataType out_phi_dtype = phi::DataType::UNDEFINED; + if ((!UnchangeOutputOps.count(op_item->name())) && + (!IsLegacyOp(op_item->name())) && phi_kernel.IsValid()) { + out_place = phi::TransToPhiPlace(output_defs[i].backend); + out_phi_dtype = output_defs[i].dtype; + } - auto result_type = op_item->result(i).type(); - if (!result_type) { - op_output_types.push_back(result_type); - } else if (result_type.isa() || - result_type.isa()) { - op_output_types.push_back( - BuildOutputType(result_type, out_place, out_phi_dtype, ctx)); - } else if (result_type.isa()) { - std::vector vec_inner_types; - auto base_types = result_type.dyn_cast().data(); - for (auto& base_type : base_types) { - if (base_type) { - if (base_type.isa()) { - vec_inner_types.push_back( - BuildOutputType(base_type, out_place, out_phi_dtype, ctx)); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "only support dense tensor in vector type for now")); - } - } else { - // NOTE(phlrain), kernel not support a nullptr in output - ir::Type fp32_dtype = ir::Float32Type::get(ctx); - phi::DDim dims = {}; - phi::DataLayout data_layout = phi::DataLayout::NCHW; - phi::LoD lod = {{}}; - size_t offset = 0; - auto dense_tensor_dtype = paddle::dialect::DenseTensorType::get( - ctx, fp32_dtype, dims, data_layout, lod, offset); - auto allocated_dense_tensor_dtype = - paddle::dialect::AllocatedDenseTensorType::get( - ctx, out_place, dense_tensor_dtype); - vec_inner_types.push_back(allocated_dense_tensor_dtype); - } + auto result_type = op_item->result(i).type(); + if (!result_type) { + op_output_types.push_back(result_type); + } else if (result_type.isa() || + result_type.isa()) { + op_output_types.push_back( + BuildOutputType(result_type, out_place, out_phi_dtype, ctx)); + } else if (result_type.isa()) { + std::vector vec_inner_types; + auto base_types = result_type.dyn_cast().data(); + for (auto& base_type : base_types) { + if (base_type) { + if (base_type.isa()) { + vec_inner_types.push_back( + BuildOutputType(base_type, out_place, out_phi_dtype, ctx)); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "only support dense tensor in vector type for now")); } - - ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types); - op_output_types.push_back(t1); } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Result type only support DenseTensorType, SelectedRowType and " - "VectorType")); + // NOTE(phlrain), kernel not support a nullptr in output + pir::Type fp32_dtype = pir::Float32Type::get(ctx); + phi::DDim dims = {}; + phi::DataLayout 
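// (The placeholder being built here stands in for a null element of a
//  vector result: kernels cannot emit a nullptr output, so the pass
//  materializes an empty fp32 DenseTensor type -- empty dims, NCHW layout,
//  empty LoD, offset 0 -- and wraps it in an AllocatedDenseTensorType at
//  the kernel's out_place.)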
data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{}}; + size_t offset = 0; + auto dense_tensor_dtype = paddle::dialect::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset); + auto allocated_dense_tensor_dtype = + paddle::dialect::AllocatedDenseTensorType::get( + ctx, out_place, dense_tensor_dtype); + vec_inner_types.push_back(allocated_dense_tensor_dtype); } } + + pir::Type t1 = pir::VectorType::get(ctx, vec_inner_types); + op_output_types.push_back(t1); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Result type only support DenseTensorType, SelectedRowType and " + "VectorType")); } + } - // constuct input - std::vector vec_inputs; - if (op_item->num_operands() > 0) { - for (size_t i = 0; i < op_item->num_operands(); ++i) { - auto cur_in = op_item->operand_source(i); - if (!cur_in) { - vec_inputs.emplace_back(); - continue; + return op_output_types; +} + +std::vector BuildOpInputList( + pir::Operation* op_item, + const std::string& kernel_fn_str, + const phi::KernelKey& kernel_key, + const phi::Place place, + const OpYamlInfoParser* op_info_parser, + pir::IrContext* ctx, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair, + pir::Program* program) { + if (op_item->num_operands() == 0) { + return {}; + } + + std::vector vec_inputs; + + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + PADDLE_ENFORCE_EQ( + map_value_pair->count(cur_in), + true, + phi::errors::PreconditionNotMet( + "[%d]'s input of [%s] op MUST in map pair", i, op_item->name())); + auto new_in = map_value_pair->at(cur_in); + + auto new_in_type = new_in.type(); + + auto& kernel = phi::KernelFactory::Instance().SelectKernelWithGPUDNN( + kernel_fn_str, kernel_key); + + bool check_place_transfer = + (op_item->name() == "builtin.set_parameter") || + (kernel.IsValid() && (!UnchangeOutputOps.count(op_item->name()))); + + if (check_place_transfer) { + if (new_in_type.isa()) { + // allocated type + auto in_place = + new_in_type.dyn_cast().place(); + + // get input args def type + auto args_def = kernel.args_def(); + auto input_defs = args_def.input_defs(); + + auto dst_backend = GetDstBackend(op_item->name(), + place, + op_info_parser, + kernel.InputAt(i).backend, + i); + + bool need_trans = + (in_place.GetType() != phi::AllocationType::UNDEFINED) && + (paddle::experimental::NeedTransformPlace( + in_place, dst_backend, {})); + if (need_trans) { + VLOG(6) << "need trans from " << in_place << " to " + << kernel_key.backend(); + // build memcopy op + auto out_place = phi::TransToPhiPlace(dst_backend); + auto new_in_alloc_type = + new_in_type.dyn_cast(); + auto out_type = dialect::AllocatedDenseTensorType::get( + ctx, + out_place, + new_in_alloc_type.dtype(), + new_in_alloc_type.dims(), + new_in_alloc_type.data_layout(), + new_in_alloc_type.lod(), + new_in_alloc_type.offset()); + new_in = AddPlaceTransferOp( + new_in, out_type, in_place, out_place, kernel_key, program); } - PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in), - true, - phi::errors::PreconditionNotMet( - "[%d]'s input of [%s] op MUST in map pair", - i, - op_item->name())); - auto new_in = map_value_pair.at(cur_in); - - auto new_in_type = new_in.type(); - - auto& kernel = phi::KernelFactory::Instance().SelectKernelWithGPUDNN( - kernel_fn_str, kernel_key); - - bool check_place_transfer = - (op_item->name() == "builtin.set_parameter") || - (kernel.IsValid() && (!UnchangeOutputOps.count(op_item->name()))); - - if 
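// (check_place_transfer, defined just above: builtin.set_parameter is
//  always checked for a placement mismatch; any other op is checked only
//  when a valid kernel was selected and the op is not in UnchangeOutputOps,
//  i.e. its inputs are expected to live where the kernel runs.)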
(check_place_transfer) { - if (new_in_type.isa()) { - // allocated type - auto in_place = - new_in_type.dyn_cast() - .place(); + } else if (new_in_type.isa()) { + // [ todo need update here, support combine data transfomer] + // deal with pre combine op + auto pre_define_op = cur_in.GetDefiningOp(); + + if (pre_define_op->name() == "builtin.combine") { + std::vector inner_inputs; + std::vector types_in_vec; + bool is_trans = false; + for (size_t j = 0; j < pre_define_op->num_operands(); ++j) { + auto in_i = map_value_pair->at(pre_define_op->operand_source(j)); + auto in_i_type = in_i.type(); + phi::Place place; + if (in_i_type.isa()) { + place = in_i_type.dyn_cast() + .place(); + } else if (in_i_type.isa()) { + place = in_i_type.dyn_cast() + .place(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "builtin.combine Input type only support " + "VectorType and " + "VectorType")); + } // get input args def type auto args_def = kernel.args_def(); auto input_defs = args_def.input_defs(); - auto dst_backend = GetDstBackend(op_item->name(), - place, - op_info_parser.get(), - kernel.InputAt(i).backend, - i); - bool need_trans = - (in_place.GetType() != phi::AllocationType::UNDEFINED) && + (place.GetType() != phi::AllocationType::UNDEFINED) && + (op_info_parser != nullptr && + !op_info_parser->IsTensorAttribute(i)) && (paddle::experimental::NeedTransformPlace( - in_place, dst_backend, {})); + place, kernel.InputAt(i).backend, {})); if (need_trans) { - VLOG(6) << "need trans from " << in_place << " to " + VLOG(6) << "need trans from " << place << " to " << kernel_key.backend(); // build memcopy op - auto out_place = phi::TransToPhiPlace(dst_backend); - auto new_in_alloc_type = - new_in_type.dyn_cast(); - auto out_type = dialect::AllocatedDenseTensorType::get( - ctx, - out_place, - new_in_alloc_type.dtype(), - new_in_alloc_type.dims(), - new_in_alloc_type.data_layout(), - new_in_alloc_type.lod(), - new_in_alloc_type.offset()); - new_in = AddPlaceTransferOp(new_in, - out_type, - in_place, - out_place, - kernel_key, - program.get()); - } - } else if (new_in_type.isa()) { - // [ todo need update here, support combine data transfomer] - // deal with pre combine op - auto pre_define_op = cur_in.GetDefiningOp(); - - if (pre_define_op->name() == "builtin.combine") { - std::vector inner_inputs; - std::vector types_in_vec; - bool is_trans = false; - for (size_t j = 0; j < pre_define_op->num_operands(); ++j) { - auto in_i = map_value_pair.at(pre_define_op->operand_source(j)); - auto in_i_type = in_i.type(); - phi::Place place; - if (in_i_type.isa()) { - place = - in_i_type.dyn_cast() - .place(); - } else if (in_i_type - .isa()) { - place = - in_i_type.dyn_cast() - .place(); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "builtin.combine Input type only support " - "VectorType and " - "VectorType")); - } - - // get input args def type - auto args_def = kernel.args_def(); - auto input_defs = args_def.input_defs(); - - bool need_trans = - (place.GetType() != phi::AllocationType::UNDEFINED) && - (op_info_parser != nullptr && - !op_info_parser->IsTensorAttribute(i)) && - (paddle::experimental::NeedTransformPlace( - place, kernel.InputAt(i).backend, {})); - if (need_trans) { - VLOG(6) << "need trans from " << place << " to " - << kernel_key.backend(); - // build memcopy op - auto out_place = - phi::TransToPhiPlace(kernel.InputAt(i).backend); - - ir::Type out_type; - if (in_i_type.isa()) { - out_type = dialect::AllocatedDenseTensorType::get( - ctx, - out_place, - pre_define_op->operand_source(j) - 
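// (For a VectorType input the transfer runs per element: each operand of
//  the upstream builtin.combine is compared against the backend the kernel
//  expects for input i, a memcpy op is inserted for mismatched elements,
//  and -- as the following hunk shows -- a fresh pir::CombineOp is rebuilt
//  from the transferred values so the consumer sees a consistent vector.)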
.type() - .dyn_cast()); - } else if (in_i_type - .isa()) { - out_type = dialect::AllocatedSelectedRowsType::get( - ctx, - out_place, - pre_define_op->operand_source(j) - .type() - .dyn_cast()); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "builtin.combine Input type only support " - "VectorType and " - "VectorType")); - } - - in_i = AddPlaceTransferOp(in_i, - out_type, - place, - out_place, - kernel_key, - program.get()); - - is_trans = true; - } - - inner_inputs.push_back(in_i); - types_in_vec.push_back(in_i.type()); - } - if (is_trans) { - // Add combine op - std::string combine_op_name(ir::CombineOp::name()); - ir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name); - - ir::Type target_vec_type = - ir::VectorType::get(ctx, types_in_vec); - ir::Operation* operation = ir::Operation::Create( - inner_inputs, {}, {target_vec_type}, op_info); - - new_in = operation->result(0); - program->block()->push_back(operation); + auto out_place = phi::TransToPhiPlace(kernel.InputAt(i).backend); + pir::Type out_type; + if (in_i_type.isa()) { + out_type = dialect::AllocatedDenseTensorType::get( + ctx, + out_place, + pre_define_op->operand_source(j) + .type() + .dyn_cast()); + } else if (in_i_type.isa()) { + out_type = dialect::AllocatedSelectedRowsType::get( + ctx, + out_place, + pre_define_op->operand_source(j) + .type() + .dyn_cast()); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "builtin.combine Input type only support " + "VectorType and " + "VectorType")); } + in_i = AddPlaceTransferOp( + in_i, out_type, place, out_place, kernel_key, program); + + is_trans = true; } - } else if (new_in_type.isa()) { - // do nothing here - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "only support allocated dense tensor type for now")); + inner_inputs.push_back(in_i); + types_in_vec.push_back(in_i.type()); + } + if (is_trans) { + // Add combine op + std::string combine_op_name(pir::CombineOp::name()); + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name); + + pir::Type target_vec_type = pir::VectorType::get(ctx, types_in_vec); + pir::Operation* operation = pir::Operation::Create( + inner_inputs, {}, {target_vec_type}, op_info); + + new_in = operation->result(0); + program->block()->push_back(operation); } } - vec_inputs.push_back(new_in); + + } else if (new_in_type.isa()) { + // do nothing here + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "only support allocated dense tensor type for now")); } } + vec_inputs.push_back(new_in); + } + + return vec_inputs; +} + +void AddShadowFeed( + const phi::Place& place, + pir::Operation* op_item, + pir::Operation* kernel_op, + pir::Program* program, + pir::IrContext* ctx, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair) { + bool feed_op_add_shadow_feed = + (op_item->name() == "pd_op.feed") && platform::is_gpu_place(place); + bool data_op_add_shadow_feed = + (op_item->name() == "pd_op.data") && platform::is_gpu_place(place) && + (kernel_op->attributes() + .at("place") + .dyn_cast() + .data() + .GetType() == phi::AllocationType::UNDEFINED); + bool add_shadow_feed = feed_op_add_shadow_feed || data_op_add_shadow_feed; + if (add_shadow_feed) { + // if shadow data op place not gpu,add shadow feed op + phi::KernelKey shadow_key{ + phi::Backend::GPU, + phi::DataLayout::ANY, + TransToPhiDataType( + op_item->result(0).type().dyn_cast().dtype())}; + std::unordered_map attr_map{ + {"op_name", pir::StrAttribute::get(ctx, "pd_op.shadow_feed")}, + {"kernel_name", pir::StrAttribute::get(ctx, "shadow_feed")}, + 
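// (AddShadowFeed fires in exactly two cases, per the booleans above: a
//  pd_op.feed lowered on a GPU place, or a pd_op.data whose "place"
//  attribute is still UNDEFINED while the pass targets GPU. The injected
//  pd_op.shadow_feed kernel copies the host tensor to the device so
//  downstream GPU kernels see device memory.)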
{"kernel_key", dialect::KernelAttribute::get(ctx, shadow_key)}}; + + auto out_type = paddle::dialect::AllocatedDenseTensorType::get( + ctx, + phi::TransToPhiPlace(shadow_key.backend()), + op_item->result(0).type().dyn_cast()); - std::unordered_map op_attribute{ - {"op_name", ir::StrAttribute::get(ctx, op_item->name())}, - {"kernel_name", ir::StrAttribute::get(ctx, kernel_fn_str)}, - {"kernel_key", dialect::KernelAttribute::get(ctx, kernel_key)}}; - auto op_attr_map = op_item->attributes(); + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(paddle::dialect::PhiKernelOp::name()); + pir::Operation* shadow_op = pir::Operation::Create( + {kernel_op->result(0)}, attr_map, {out_type}, phi_kernel_op_info); - for (auto& map_item : op_attr_map) { - op_attribute.emplace(map_item.first, map_item.second); + (*map_op_pair)[op_item] = shadow_op; + program->block()->push_back(shadow_op); + if (op_item->num_results() > 0) { + for (size_t i = 0; i < shadow_op->num_results(); ++i) { + (*map_value_pair)[op_item->result(i)] = shadow_op->result(i); + } } + } +} + +std::unique_ptr GetOpYamlInfoParser(pir::Operation* op) { + paddle::dialect::OpYamlInfoInterface op_info_interface = + op->dyn_cast(); + + std::unique_ptr op_info_parser(nullptr); + if (op_info_interface) { + op_info_parser = + std::make_unique(op_info_interface.GetOpInfo()); + } + + return op_info_parser; +} + +std::string GetKernelFnStr(const OpYamlInfoParser* op_info_parser, + pir::Operation* op_item) { + std::string kernel_fn_str; + if (op_info_parser != nullptr) { + kernel_fn_str = op_info_parser->OpRuntimeInfo().kernel_func[0]; + } - if (op_item->HasTrait()) { - op_attribute.emplace("is_inplace", ir::BoolAttribute::get(ctx, true)); + if (op_item->name() == "pd_op.add_n_" || + op_item->name() == "pd_op.add_n_with_kernel") { + if (op_item->result(0).type().isa()) { + kernel_fn_str = "add_n_sr"; } + } + return kernel_fn_str; +} - ir::Operation* op; - if (dialect::IsLegacyOp(op_item->name())) { - op = ir::Operation::Create( - vec_inputs, op_attribute, op_output_types, legacy_kernel_op_info); - } else { - op = ir::Operation::Create( - vec_inputs, op_attribute, op_output_types, phi_kernel_op_info); +pir::Operation* BuildPhiKernelOp( + const std::string& kernel_fn_str, + const phi::KernelKey& kernel_key, + const std::vector& vec_inputs, + const std::vector& op_output_types, + pir::Operation* op_item, + pir::Program* program, + pir::IrContext* ctx, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair) { + std::unordered_map op_attribute{ + {"op_name", pir::StrAttribute::get(ctx, op_item->name())}, + {"kernel_name", pir::StrAttribute::get(ctx, kernel_fn_str)}, + {"kernel_key", dialect::KernelAttribute::get(ctx, kernel_key)}}; + auto op_attr_map = op_item->attributes(); + + for (auto& map_item : op_attr_map) { + op_attribute.emplace(map_item.first, map_item.second); + } + + if (op_item->HasTrait()) { + op_attribute.emplace("is_inplace", pir::BoolAttribute::get(ctx, true)); + } + + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(paddle::dialect::PhiKernelOp::name()); + + pir::OpInfo legacy_kernel_op_info = + ctx->GetRegisteredOpInfo(paddle::dialect::LegacyKernelOp::name()); + pir::Operation* op; + if (dialect::IsLegacyOp(op_item->name())) { + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, legacy_kernel_op_info); + } else { + op = pir::Operation::Create( + vec_inputs, op_attribute, op_output_types, phi_kernel_op_info); + } + + (*map_op_pair)[op_item] = op; + + // only deal with single 
output + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + (*map_value_pair)[op_item->result(i)] = op->result(i); } + } + program->block()->push_back(op); - map_op_pair[op_item] = op; + return op; +} - // only deal with single output - if (op_item->num_results() > 0) { - for (size_t i = 0; i < op_item->num_results(); ++i) { - map_value_pair[op_item->result(i)] = op->result(i); - } +std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, + phi::Place place) { + if (VLOG_IS_ON(2)) { + std::stringstream ss; + prog->Print(ss); + VLOG(2) << "Program after lowering to kernel pass : " << ss.str(); + } + + auto program = std::make_unique(pir::IrContext::Instance()); + + auto block = prog->block(); + + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + std::unordered_map map_op_pair; + std::unordered_map map_value_pair; + + auto skip_feed_names = GetSkipFeedNames(block); + + for (auto op_item : *block) { + VLOG(6) << "op name " << op_item->name(); + if ((op_item->name() == "pd_op.feed") && + SkipFeedOp(op_item, skip_feed_names)) { + continue; } - program->block()->push_back(op); - bool feed_op_add_shadow_feed = - (op_item->name() == "pd.feed") && platform::is_gpu_place(place); - bool data_op_add_shadow_feed = (op_item->name() == "pd.data") && - platform::is_gpu_place(place) && - (op->attributes() - .at("place") - .dyn_cast() - .data() - .GetType() != phi::AllocationType::GPU); - bool add_shadow_feed = feed_op_add_shadow_feed || data_op_add_shadow_feed; - if (add_shadow_feed) { - // if shadow data op place not gpu,add shadow feed op - phi::KernelKey shadow_key{ - phi::Backend::GPU, - phi::DataLayout::ANY, - TransToPhiDataType( - op_item->result(0).type().dyn_cast().dtype())}; - std::unordered_map attr_map{ - {"op_name", ir::StrAttribute::get(ctx, "pd.shadow_feed")}, - {"kernel_name", ir::StrAttribute::get(ctx, "shadow_feed")}, - {"kernel_key", dialect::KernelAttribute::get(ctx, shadow_key)}}; - - auto out_type = paddle::dialect::AllocatedDenseTensorType::get( - ctx, - phi::TransToPhiPlace(shadow_key.backend()), - op_item->result(0).type().dyn_cast()); - - ir::Operation* shadow_op = ir::Operation::Create( - {op->result(0)}, attr_map, {out_type}, phi_kernel_op_info); - - map_op_pair[op_item] = shadow_op; - program->block()->push_back(shadow_op); - if (op_item->num_results() > 0) { - for (size_t i = 0; i < shadow_op->num_results(); ++i) { - map_value_pair[op_item->result(i)] = shadow_op->result(i); - } - } + // HandleSpecialOp + if (SpecialLowerOps.count(op_item->name())) { + HandleForSpecialOp( + op_item, program.get(), ctx, &map_op_pair, &map_value_pair); + continue; } + + // Lower from PaddleDialect to KernelDialect + + auto op_info_parser = GetOpYamlInfoParser(op_item); + + auto kernel_fn_str = GetKernelFnStr(op_info_parser.get(), op_item); + + auto kernel_key = GetKernelKey( + op_item, place, kernel_fn_str, map_value_pair, op_info_parser.get()); + VLOG(6) << "kernel type " << kernel_key; + + // build output type + auto op_output_types = + BuildOpOutputType(op_item, kernel_fn_str, kernel_key, ctx); + + // build input + auto vec_inputs = BuildOpInputList(op_item, + kernel_fn_str, + kernel_key, + place, + op_info_parser.get(), + ctx, + &map_op_pair, + &map_value_pair, + program.get()); + + // build op + pir::Operation* op = BuildPhiKernelOp(kernel_fn_str, + kernel_key, + vec_inputs, + op_output_types, + op_item, + program.get(), + ctx, + &map_op_pair, + &map_value_pair); + + AddShadowFeed( + 
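// Per-op flow of the rewritten loop above, now a straight pipeline:
//   1. skip pd_op.feed ops that match the skip-feed list
//   2. HandleForSpecialOp: deep-copy builtin.combine / slice / split
//   3. GetOpYamlInfoParser + GetKernelFnStr + GetKernelKey
//   4. BuildOpOutputType  -> lowered result types
//   5. BuildOpInputList   -> lowered operands, inserting memcpy ops
//   6. BuildPhiKernelOp   -> phi kernel / legacy kernel operation
//   7. AddShadowFeed      -> host-to-device copy for feed/data on GPU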
place, op_item, op, program.get(), ctx, &map_op_pair, &map_value_pair); } + if (VLOG_IS_ON(2)) { std::stringstream ss1; program->Print(ss1); @@ -1184,6 +1219,5 @@ std::unique_ptr PdOpLowerToKernelPass(ir::Program* prog, } return program; } - } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h similarity index 83% rename from paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h rename to paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h index 3e4848720f4ce..acf839391b8c5 100644 --- a/paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h @@ -13,14 +13,14 @@ // limitations under the License. #pragma once -#include "paddle/ir/core/program.h" #include "paddle/phi/common/place.h" +#include "paddle/pir/core/program.h" namespace paddle { namespace dialect { -std::unique_ptr PdOpLowerToKernelPass( - ir::Program* prog, phi::Place place = phi::CPUPlace()); +std::unique_ptr PdOpLowerToKernelPass( + pir::Program* prog, phi::Place place = phi::CPUPlace()); } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/transforms/transform_general_functions.cc b/paddle/fluid/pir/transforms/transform_general_functions.cc similarity index 74% rename from paddle/fluid/ir/transforms/transform_general_functions.cc rename to paddle/fluid/pir/transforms/transform_general_functions.cc index 587c0cdaacd1d..6da131ee5e0c0 100644 --- a/paddle/fluid/ir/transforms/transform_general_functions.cc +++ b/paddle/fluid/pir/transforms/transform_general_functions.cc @@ -12,36 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/parameter.h" -#include "paddle/ir/core/program.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/parameter.h" +#include "paddle/pir/core/program.h" -namespace ir { +namespace pir { -std::pair GetParameterFromValue(ir::Value value) { - ir::GetParameterOp op = value.GetDefiningOp()->dyn_cast(); +std::pair GetParameterFromValue( + pir::Value value) { + pir::GetParameterOp op = + value.GetDefiningOp()->dyn_cast(); PADDLE_ENFORCE_NOT_NULL( op, phi::errors::InvalidArgument( "Value must be a weight from a GetParameter op.")); - ir::Program* program = op->GetParentProgram(); + pir::Program* program = op->GetParentProgram(); PADDLE_ENFORCE_NOT_NULL( program, phi::errors::InvalidArgument("Program should not be null.")); std::string name = op->attributes() .at(op.attributes_name[0]) - .dyn_cast() + .dyn_cast() .AsString(); - ir::Parameter* param = program->GetParameter(name); + pir::Parameter* param = program->GetParameter(name); PADDLE_ENFORCE_NOT_NULL( param, phi::errors::InvalidArgument("Parameter should not be null.")); return {name, param}; } -const phi::DDim& GetShapeFromValue(ir::Value value) { +const phi::DDim& GetShapeFromValue(pir::Value value) { // TODO(dev): Support other types like DenseTensor. 
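// Usage sketch for the three helpers in this file (assuming `value` comes
// from an existing pir::Program; the structured binding is illustrative):
//
//   auto [name, param]    = pir::GetParameterFromValue(value);  // weight lookup
//   const phi::DDim& dims = pir::GetShapeFromValue(value);      // tensor dims
//   pir::Type dtype       = pir::GetDataTypeFromValue(value);   // element type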
PADDLE_ENFORCE_EQ( value.type().isa(), @@ -50,7 +52,7 @@ const phi::DDim& GetShapeFromValue(ir::Value value) { return value.type().dyn_cast().dims(); } -ir::Type GetDataTypeFromValue(ir::Value value) { +pir::Type GetDataTypeFromValue(pir::Value value) { // TODO(dev): Support other types like DenseTensor. PADDLE_ENFORCE_EQ( value.type().isa(), @@ -75,4 +77,4 @@ Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) { return op->result(index).first_use().owner(); } -} // namespace ir +} // namespace pir diff --git a/paddle/fluid/ir/transforms/transform_general_functions.h b/paddle/fluid/pir/transforms/transform_general_functions.h similarity index 76% rename from paddle/fluid/ir/transforms/transform_general_functions.h rename to paddle/fluid/pir/transforms/transform_general_functions.h index b086af090f7a1..77c790235b832 100644 --- a/paddle/fluid/ir/transforms/transform_general_functions.h +++ b/paddle/fluid/pir/transforms/transform_general_functions.h @@ -14,45 +14,45 @@ #pragma once -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/parameter.h" -#include "paddle/ir/core/type.h" -#include "paddle/ir/core/value.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/parameter.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/value.h" -namespace ir { +namespace pir { /** * @brief Get the [name, parameter] pair of pararmeter from a value. * * @note The value must be a output of a GetParameterOp. * - * @param ir::Value + * @param pir::Value * - * @return std::pair + * @return std::pair */ -std::pair GetParameterFromValue(ir::Value value); +std::pair GetParameterFromValue(pir::Value value); /** * @brief Get tensor's shape from a value. * - * @param ir::Value + * @param pir::Value * * @return const phi::DDim& */ -const phi::DDim& GetShapeFromValue(ir::Value value); +const phi::DDim& GetShapeFromValue(pir::Value value); /** * @brief Get tensor's data type from a value. * - * @param ir::Value + * @param pir::Value * - * @return ir::Type + * @return pir::Type */ -ir::Type GetDataTypeFromValue(ir::Value value); +pir::Type GetDataTypeFromValue(pir::Value value); /** * @brief Get an operation that defines the specific input of the operation. @@ -75,4 +75,4 @@ Operation* GetDefiningOpForInput(Operation* op, uint32_t index); */ Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index); -} // namespace ir +} // namespace pir diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 9343c8ddf7781..eae360c146df5 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -63,11 +63,7 @@ limitations under the License. */ #endif PHI_DECLARE_int32(paddle_num_threads); -PADDLE_DEFINE_EXPORTED_int32( - multiple_of_cupti_buffer_size, - 1, - "Multiple of the CUPTI device buffer size. 
If the timestamps have " - "been dropped when you are profiling, try increasing this value."); +PHI_DECLARE_int32(multiple_of_cupti_buffer_size); namespace paddle { namespace framework { diff --git a/paddle/fluid/primitive/backend/CMakeLists.txt b/paddle/fluid/primitive/backend/CMakeLists.txt index deabc1f19d9b5..d352880871121 100644 --- a/paddle/fluid/primitive/backend/CMakeLists.txt +++ b/paddle/fluid/primitive/backend/CMakeLists.txt @@ -12,4 +12,4 @@ set(static_backend_files cc_library( primitive_backend_static_experimental SRCS ${static_backend_files} - DEPS pd_dialect_api) + DEPS pd_op_dialect_api) diff --git a/paddle/fluid/primitive/backend/manual/manual_static_backend.cc b/paddle/fluid/primitive/backend/manual/manual_static_backend.cc index de39a58473337..7d96b4ddfecc2 100644 --- a/paddle/fluid/primitive/backend/manual/manual_static_backend.cc +++ b/paddle/fluid/primitive/backend/manual/manual_static_backend.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" #include "paddle/fluid/primitive/backend/manual/manual_backend.h" #include "paddle/fluid/primitive/primitive/primitive.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 43eb5005f0f52..815f41e6fdb03 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -16,7 +16,6 @@ import hashlib import pathlib import sys -from typing import Dict, List import jinja2 import yaml @@ -28,10 +27,11 @@ ) import filters as op_gen_filters import tests_utils as op_gen_tests +from parse_utils import to_named_dict -# import from paddle/fluid/ir/dialect/op_generator/api_gen.py +# import from paddle/fluid/pir/dialect/op_generator/api_gen.py sys.path.append( - str(pathlib.Path(__file__).resolve().parents[2] / 'ir/dialect/op_generator') + str(pathlib.Path(__file__).resolve().parents[2] / 'pir/dialect/op_generator') ) # fmt: on @@ -61,9 +61,21 @@ 'rsqrt_grad', 'slice_grad', 'transpose_grad', + 'square_grad', 'dropout_grad', + 'cast_grad', + 'slice_double_grad', + 'layer_norm_grad', + 'embedding_grad', + 'add_n_grad', + 'scale_grad', ] -VJP_COMPS = ['divide_grad', 'sum_grad', 'gelu_grad'] + + +PRIM_VJP = ['divide_grad', 'sum_grad'] # vjp list of primitive op +CUSTOM_VJP = ['gelu_grad'] # custom vjp list of composite op +VJP_COMPS = PRIM_VJP + CUSTOM_VJP + BACKENDS = [ 'add_n', 'mean', @@ -129,7 +141,14 @@ 'roll', 'scatter', 'scatter_nd_add', + 'square_grad', 'dropout_grad', + 'slice', + 'layer_norm_grad', + 'embedding_grad', + 'add_n_grad', + 'sqrt', + 'uniform', ] @@ -219,21 +238,6 @@ def save(content: str, path: pathlib.Path): print(f"Generate source file {path}") -def to_compat_dict(items: List[Dict]) -> Dict[str, Dict]: - compat_dict = {} - for item in items: - name = item["op"] - compat_dict[name] = item - return compat_dict - - -def to_apis_dict(apis): - apis_dict = {} - for api in apis: - apis_dict[api['name']] = api - return apis_dict - - def get_inplace_api(apis): inplace_apis = [] for api in apis: @@ -271,7 +275,7 @@ def extend_compat_info(apis, compats): attr['typename'] ) or op_gen_tests.is_intarray(attr['typename']): attr["support_tensor"] = False - apis_dict = to_apis_dict(apis) + apis_dict = to_named_dict(apis) for compat_item in compats: fwd_op_name = compat_item["op"] if fwd_op_name not in apis_dict: @@ 
-322,6 +326,31 @@ def extend_compat_info(apis, compats): return apis +def process_backward_invoke_info(apis): + apis_dict = to_named_dict(apis) + for api in apis: + if api['is_fwd']: + continue + if 'invoke' in api and api['invoke']['func'] in apis_dict: + args = api['invoke']['args'].split(',') + args = [arg.strip() for arg in args] + attrs_dict = to_named_dict(api['attrs']) + inputs_dict = to_named_dict(api['inputs']) + arg_inputs = [] + arg_attrs = [] + for arg in args: + if arg in inputs_dict: + arg_inputs.append(arg) + elif arg in attrs_dict and attrs_dict[arg].get( + "support_tensor", False + ): + arg_inputs.append(arg + '_') + else: + arg_attrs.append(arg) + args = arg_inputs + arg_attrs + api['invoke']['args'] = ', '.join(args) + + def gen( prim_path: pathlib.Path, fwd_path: pathlib.Path, @@ -369,6 +398,7 @@ def gen( ] apis = extend_compat_info(apis, compats) apis = apis + get_inplace_api(apis) + process_backward_invoke_info(apis) render( templates_dir, destination_dir, diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 index 3bbd00d967b83..663467af25a97 100644 --- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 +++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 @@ -7,6 +7,7 @@ #include #include "paddle/phi/api/include/tensor.h" +#include "paddle/utils/optional.h" namespace paddle { diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 index 8e004c22eeeb5..48292d27243e6 100644 --- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 @@ -2,7 +2,7 @@ // Auto Generated, DO NOT EDIT! 
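For orientation, the macros below expand to wrappers of roughly the following shape. This is a hand-written sketch for a hypothetical unary op ("tanh" is illustrative only), not the template's literal output; it assumes the LazyTensor::value() accessor introduced elsewhere in this patch.

template <>
Tensor tanh<LazyTensor>(const Tensor& x) {
  // unwrap the lazy tensor into the pir value it records
  pir::OpResult x_res = std::static_pointer_cast<LazyTensor>(x.impl())
                            ->value()
                            .dyn_cast<pir::OpResult>();
  // trace the op into the program via the dialect API
  auto op_res = paddle::dialect::tanh(x_res);
  // re-wrap the result so callers keep working with Tensor
  Tensor out(std::make_shared<LazyTensor>(op_res));
  return out;
}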
#include "paddle/fluid/primitive/backend/generated/generated_backend.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" #include "paddle/fluid/primitive/primitive/primitive.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" @@ -17,22 +17,69 @@ template <> {{common.ret(outputs)}} {{name}}({{common.params(inputs, attrs, mutable_attribute_as_inputs, False)}}) {%- endmacro -%} -{% macro body(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False) %} - {%- set output_names = [] -%} - {%- for o in outputs -%} {%- do output_names.append(o.name) -%} {%-endfor-%} +{%- macro prepare_ir_api_inputs(inputs)-%} {%- for input in inputs -%} - {% if input.typename=='Tensor[]' %} - std::vector {{input.name}}_res({{input.name}}.size()); - std::transform({{input.name}}.begin(), {{input.name}}.end(), {{input.name}}_res.begin(), [](const Tensor& t) { - return std::static_pointer_cast(t.impl())->getValue().dyn_cast(); + {% if input.typename=='Tensor[]' and not input.optional %} +std::vector {{input.name}}_res({{input.name}}.size()); +std::transform({{input.name}}.begin(), {{input.name}}.end(), {{input.name}}_res.begin(), [](const Tensor& t) { + return std::static_pointer_cast(t.impl())->value().dyn_cast(); +}); + {% elif input.typename=='Tensor[]' and input.optional %} +std::vector {{input.name}}_res({{input.name}}.size()); +if({{input.name}}) { + std::transform({{input.name}}.get().begin(), {{input.name}}.get().end(), {{input.name}}_res.begin(), [](const Tensor& t) { + return std::static_pointer_cast(t.impl())->value().dyn_cast(); }); +} + {% elif input.typename=='Tensor' and not input.optional %} +pir::OpResult {{input.name}}_res = std::static_pointer_cast({{input.name}}.impl())->value().dyn_cast(); {% else %} - ir::OpResult {{input.name}}_res = std::static_pointer_cast({{input.name}}.impl())->getValue().dyn_cast(); +pir::OpResult {{input.name}}_res; +if({{input.name}}) { + {{input.name}}_res = std::static_pointer_cast({{input.name}}.get().impl())->value().dyn_cast(); +} {% endif %} {% endfor %} - {%- for attr in attrs -%} +{%- endmacro -%} + +{%- macro get_static_backend_outputs(outputs)-%} + {%- if outputs|length == 1 -%} + {%- if outputs[0].typename == 'Tensor' -%} +Tensor {{outputs[0].name}}(std::make_shared(op_res)); +return {{outputs[0].name}}; + {%- elif outputs[0].typename == 'Tensor[]' -%} +std::vector {{outputs[0].name}}(op_res.size()); +std::transform(op_res.begin(), op_res.end(), {{outputs[0].name}}.begin(), [](const pir::OpResult& res) { +return Tensor(std::make_shared(res)); + }); +return {{outputs[0].name}}; + {%- else -%} {#- render nothing -#} + {%- endif -%} + {%- elif outputs|length > 1 -%} + {%- for i in range(outputs|length) %} +auto op_res_{{i}} = std::get<{{i}}>(op_res); + {% if outputs[i].typename == 'Tensor' %} +Tensor {{outputs[i].name}}(std::make_shared(op_res_{{i}})); + {% elif outputs[i].typename == 'Tensor[]' %} +std::vector {{outputs[i].name}}(op_res_{{i}}.size()); +std::transform(op_res_{{i}}.begin(), op_res_{{i}}.end(), {{outputs[i].name}}.begin(), [](const pir::OpResult& res) { +return Tensor(std::make_shared(res)); + }); + {% else %} {#- render nothing -#} + {% endif %} + {% endfor -%} +return std::make_tuple({%- for i in range(outputs|length) -%}{{outputs[i].name}}{%- if i!=outputs|length - 1 -%}, {% endif -%}{%- endfor -%}); + {%- else -%} {#- render nothing -#} + {%- endif -%} +{%- endmacro -%} + +{% macro body(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False) %} + {%- set 
output_names = [] -%} + {%- for o in outputs -%} {%- do output_names.append(o.name) -%} {%-endfor-%} +{{prepare_ir_api_inputs(inputs)}} + {%- for attr in attrs %} {% if mutable_attribute_as_inputs and attr is mutable_attribute %} - ir::OpResult {{attr.name}}_res = std::static_pointer_cast({{attr.name~'_'}}.impl())->getValue().dyn_cast(); +pir::OpResult {{attr.name}}_res = std::static_pointer_cast({{attr.name~'_'}}.impl())->value().dyn_cast(); {% endif %} {% endfor %} {%- set input_names = [] -%} @@ -52,48 +99,25 @@ template <> {%- do attr_names.append(common.phi2ir_attr(i)) -%} {%- endif -%} {% endfor %} - auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}}); - {% if outputs|length == 1 %} - {% if outputs[0].typename == 'Tensor' %} - Tensor {{outputs[0].name}}(std::make_shared(op_res)); - return {{outputs[0].name}}; - {% elif outputs[0].typename == 'Tensor[]' %} - std::vector {{outputs[0].name}}(op_res.size()); - std::transform(op_res.begin(), op_res.end(), {{outputs[0].name}}.begin(), [](const ir::OpResult& res) { - return Tensor(std::make_shared(res)); - }); - return {{outputs[0].name}}; - {% else %} {#- render nothing -#} - {% endif %} - {% elif outputs|length > 1 %} - {% for i in range(outputs|length) %} - auto op_res_{{i}} = std::get<{{i}}>(op_res); - {% if outputs[i].typename == 'Tensor' %} - Tensor {{outputs[i].name}}(std::make_shared(op_res_{{i}})); - {% elif outputs[i].typename == 'Tensor[]' %} - std::vector {{outputs[i].name}}(op_res_{{i}}.size()); - std::transform(op_res_{{i}}.begin(), op_res_{{i}}.end(), {{outputs[i].name}}.begin(), [](const ir::OpResult& res) { - return Tensor(std::make_shared(res)); - }); - {% else %} {#- render nothing -#} - {% endif %} - {% endfor %} - return std::make_tuple({% for i in range(outputs|length) %}{{outputs[i].name}}{%- if i!=outputs|length - 1 -%}, {% endif %}{% endfor %}); - {% else %} {#- render nothing -#} - {% endif %} -{% endmacro %} +auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}}); +{{get_static_backend_outputs(outputs)}} +{%- endmacro %} {% for api in apis %} {% if api.name in backend_white_list %} {% set api_outputs = api.outputs | trip_intermediate %} {{sig(api.name, api.inputs, api_outputs, api.attrs)}} { + {% filter indent(2, True) %} {{body(api.name, api.inputs, api_outputs, api.attrs)}} + {% endfilter %} } {% if api.attrs is exist_mutable_attribute %} {{sig(api.name, api.inputs, api_outputs, api.attrs, True)}} { + {% filter indent(2, True) %} {{body(api.name, api.inputs, api_outputs, api.attrs, True)}} + {% endfilter %} } {% endif %} diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 index 6d69433737633..67485bdd5a5cd 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 @@ -2,15 +2,16 @@ // Auto Generated, DO NOT EDIT! 
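The net effect of the changes to the generated vjp bodies below is easiest to see on a concrete expansion. A rough, hand-written sketch for a hypothetical single-output grad op follows (name, signature, and the backend call form are illustrative):

std::vector<std::vector<Tensor>> tanh_vjp(
    const Tensor& out,
    const Tensor& grad_out,
    const std::vector<std::vector<bool>>& stop_gradients) {
  std::vector<std::vector<Tensor>> vjp_res(1, std::vector<Tensor>(1));
  // every result is stored unconditionally now ...
  auto op_res = backend::tanh_grad<LazyTensor>(out, grad_out);
  vjp_res[0][0] = op_res;
  // ... and stop_gradient masking happens once, at the end
  vjp_res = ConstructVjpResultByStopGradients(vjp_res, stop_gradients);
  return vjp_res;
}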
#include "paddle/fluid/primitive/rule/vjp/generated/generated_vjp.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/prim/utils/static/static_global_utils.h" #include "paddle/fluid/primitive/backend/backend.h" #include "paddle/fluid/primitive/rule/vjp/details.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" #include "paddle/fluid/primitive/utils/utils.h" -#include "paddle/ir/core/operation.h" +#include "paddle/pir/core/operation.h" #include "paddle/phi/core/flags.h" +#include "paddle/utils/optional.h" PHI_DECLARE_string(tensor_operants_mode); @@ -33,23 +34,23 @@ if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) { } {% else %} {{body_unprim(api)}} - {% endif %} + {%- endif %} return vjp_res; -{% endmacro %} +{%- endmacro -%} {% macro get_mutable_attribute(attrs, api_name) %} {% for i in attrs %} {%- if i is mutable_attribute -%} -auto* {{i.name}}_define_op = std::static_pointer_cast({{i.name~'_'}}.impl())->getValue().dyn_cast().GetDefiningOp(); +auto* {{i.name}}_define_op = std::static_pointer_cast({{i.name~'_'}}.impl())->value().dyn_cast().GetDefiningOp(); {% if i.typename is scalar %} -if({{i.name}}_define_op->name() != "pd.full") { +if({{i.name}}_define_op->name() != "pd_op.full") { PADDLE_THROW(platform::errors::Unimplemented( "We don't support dynamic tensors attribute {{i.name}} for {{api_name}} composite " "for now. ")); } auto {{i.name}} = {{i.name}}_define_op->attribute("value").dyn_cast().data(); {% elif i.typename is intarray %} -if({{i.name}}_define_op->name() != "pd.full_int_array"){ +if({{i.name}}_define_op->name() != "pd_op.full_int_array"){ PADDLE_THROW(platform::errors::Unimplemented( "We don't support dynamic tensors attribute {{i.name}} for {{api_name}} composite " "for now. ")); @@ -62,6 +63,7 @@ auto {{i.name}} = {{i.name}}_define_op->attribute("value").dyn_castattribute("value").dyn_cast({{api.invoke.args}}); + {% else %} auto op_res = backend::{{api.name}}({{common.args(input_names, attr_names)}}); + {% endif %} {% set outputs = api.outputs|trip_intermediate %} {#- ignore intermediate output -#} {% if outputs|length > 1 %} {% for i in range(outputs|length) %} -auto out{{i}} = std::get<{{i}}>(op_res); {% if outputs[i].typename=='Tensor' %} -vjp_res[{{i}}][0] = !stop_gradients[{{i}}][0] ? out{{i}} : vjp_res[{{i}}][0]; +vjp_res[{{i}}][0] = std::get<{{i}}>(op_res); {% else %} -for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) { - vjp_res[{{i}}][i] = !stop_gradients[{{i}}][i] ? out{{i}}[i] : vjp_res[{{i}}][i]; -} +vjp_res[{{i}}] = std::get<{{i}}>(op_res); {% endif %} {% endfor %} {% elif outputs|length == 1 %} {% if outputs[0].typename=='Tensor' %} -vjp_res[0][0] = !stop_gradients[0][0] ? op_res : vjp_res[0][0]; +vjp_res[0][0] = op_res; {% else %} -for (size_t i=0; i< stop_gradients[0].size(); i++ ) { - vjp_res[0][i] = !stop_gradients[0][i] ? 
op_res[i] : vjp_res[0][i]; -} +vjp_res[0] = op_res; {% endif %} {% else %} {#- render nothing -#} {% endif %} +vjp_res = ConstructVjpResultByStopGradients(vjp_res, stop_gradients); {% endmacro %} {% macro body_prim(api) %} @@ -120,7 +122,7 @@ details::{{api.composite.func_name}}({{api.composite.func_args}}); {{sig(api.name, backward_api.name, backward_api.inputs, backward_api.attrs, backward_api.outputs)}} { {% filter indent(2, True) %} {{body(backward_api)}} - {% endfilter %} + {% endfilter -%} } {% endif %} diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2 index b9e758aaa73ff..7f403661fea05 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2 @@ -4,7 +4,7 @@ #pragma once #include "paddle/fluid/primitive/primitive/primitive.h" -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/value.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/common/int_array.h" diff --git a/paddle/fluid/primitive/rule/vjp/CMakeLists.txt b/paddle/fluid/primitive/rule/vjp/CMakeLists.txt index 3d6906bb33ca5..4b790fd07900b 100644 --- a/paddle/fluid/primitive/rule/vjp/CMakeLists.txt +++ b/paddle/fluid/primitive/rule/vjp/CMakeLists.txt @@ -5,4 +5,4 @@ cc_library( primitive_vjp_experimental SRCS ${VJP_SRCS} DEPS primitive_backend_static_experimental static_global_utils - primitive_static_utils_experimental pd_dialect_core) + primitive_static_utils_experimental pd_op_dialect_core) diff --git a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc index c56ac5c5f79ab..a882f78c52018 100644 --- a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc @@ -15,13 +15,13 @@ // Auto Generated, DO NOT EDIT! #include "paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" #include "paddle/fluid/prim/utils/static/static_global_utils.h" #include "paddle/fluid/primitive/backend/backend.h" #include "paddle/fluid/primitive/rule/vjp/details.h" #include "paddle/fluid/primitive/type/lazy_tensor.h" #include "paddle/fluid/primitive/utils/utils.h" -#include "paddle/ir/core/operation.h" +#include "paddle/pir/core/operation.h" namespace paddle { namespace primitive {} // namespace primitive diff --git a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h index 0fffd6ba31a4c..35810f6d652ca 100644 --- a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h +++ b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.h @@ -15,9 +15,9 @@ #pragma once #include "paddle/fluid/primitive/primitive/primitive.h" -#include "paddle/ir/core/value.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/common/int_array.h" +#include "paddle/pir/core/value.h" namespace paddle { namespace primitive { diff --git a/paddle/fluid/primitive/type/lazy_tensor.h b/paddle/fluid/primitive/type/lazy_tensor.h index bb0af2ef374ca..cde6ece54b163 100644 --- a/paddle/fluid/primitive/type/lazy_tensor.h +++ b/paddle/fluid/primitive/type/lazy_tensor.h @@ -13,12 +13,12 @@ // limitations under the License. 
 #pragma once

-#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
-#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
-#include "paddle/ir/core/value.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/extended_tensor.h"
 #include "paddle/phi/core/utils/data_type.h"
+#include "paddle/pir/core/value.h"

 namespace paddle {
 namespace primitive {
@@ -26,7 +26,7 @@ namespace primitive {
 class LazyTensor : public phi::ExtendedTensor,
                    public phi::TypeInfoTraits<phi::TensorBase, LazyTensor> {
  public:
-  explicit LazyTensor(ir::Value value)
+  explicit LazyTensor(pir::Value value)
       : value_(value),
         dims_(
             value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims()) {
 }
@@ -41,14 +41,16 @@ class LazyTensor : public phi::ExtendedTensor,
         value_.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype());
   }

-  ir::Value getValue() const { return value_; }
+  pir::Value value() const { return value_; }

   const phi::Place& place() const override { return place_; }

   bool initialized() const override { return value_.impl() != nullptr; }

+  void set_empty_type() { value_.set_type(pir::Type()); }
+
  private:
-  ir::Value value_;
+  pir::Value value_;
   mutable phi::DDim dims_;
   phi::Place place_;
 };
diff --git a/paddle/fluid/primitive/utils/static_utils.cc b/paddle/fluid/primitive/utils/static_utils.cc
index 40cbbc8d21e89..21b970561d7c9 100644
--- a/paddle/fluid/primitive/utils/static_utils.cc
+++ b/paddle/fluid/primitive/utils/static_utils.cc
@@ -21,5 +21,48 @@ void set_output(const paddle::Tensor& x_tmp, paddle::Tensor* x) {
   x->set_impl(x_tmp.impl());
 }

+/**
+ * @brief Set outputs that need no gradient in the new IR.
+ *
+ * In the new IR, a None type expresses that a value
+ * is not available.
+ * Some outputs in vjp are marked as unnecessary
+ * by stop_gradient being True. Therefore the
+ * types of those unnecessary outputs will
+ * be set to None.
+ * + */ +void SetOutputWithNoGrads( + const std::vector>& outputs, + const std::vector>& stop_gradients) { + for (size_t i = 0; i < outputs.size(); ++i) { + for (size_t j = 0; j < outputs[i].size(); ++j) { + if (stop_gradients[i][j]) { + std::static_pointer_cast(outputs[i][j].impl()) + ->set_empty_type(); + } + } + } +} + +std::vector> ConstructVjpResultByStopGradients( + const std::vector>& outputs, + const std::vector>& stop_gradients) { + SetOutputWithNoGrads(outputs, stop_gradients); + std::vector> vjp_results(outputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + vjp_results[i].reserve(outputs[i].size()); + for (size_t j = 0; j < outputs[i].size(); ++j) { + if (stop_gradients[i][j]) { + // Use Tensor's impl is nullptr to indicate it has no gradient + vjp_results[i].emplace_back(Tensor()); + } else { + vjp_results[i].emplace_back(outputs[i][j]); + } + } + } + return vjp_results; +} + } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/utils/utils.h b/paddle/fluid/primitive/utils/utils.h index e1765357aa9f8..3a5205c256130 100644 --- a/paddle/fluid/primitive/utils/utils.h +++ b/paddle/fluid/primitive/utils/utils.h @@ -16,6 +16,7 @@ #include #include "paddle/fluid/operators/common_infer_shape_functions.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/core/ddim.h" @@ -87,5 +88,12 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims, return get_reduce_dims_from_out(out_dims, x_dims); } +void SetOutputWithNoGrads(const std::vector>& outputs, + const std::vector>& stop_gradients); + +std::vector> ConstructVjpResultByStopGradients( + const std::vector>& outputs, + const std::vector>& stop_gradients); + } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 30cb90a5d2042..6c0c0fb4f81f2 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -39,10 +39,10 @@ set(PYBIND_DEPS phi_utils phi phi_kernel_adaptor - pd_dialect + pd_op_dialect program_translator pd_inplace_pass - ir + pir new_profiler jit_layer jit_property @@ -344,7 +344,7 @@ if(WITH_PYTHON) add_custom_command( OUTPUT ${op_impl_path}/ir.dll COMMAND ${CMAKE_COMMAND} -E copy ${IR_LIB} ${op_impl_path} - DEPENDS ir) + DEPENDS pir) list(APPEND EAGER_OP_IMPL_DEPS ${op_impl_path}/ir.dll) endif() diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 8cf3a4dbbab07..27d6a75ba0736 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -14,6 +14,7 @@ #include #include +#include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_desc.h" @@ -381,6 +382,44 @@ void BindAutoParallel(py::module *m) { } return self.InferForward(ctx); }) + .def("infer_forward", // for op that have vector argument + [](const phi::distributed::SpmdRule &self, + const std::vector> &input_ranges, + const std::vector &input_specs, + const std::vector &attrs) { + /* + to distingish between single tensor argument and vector argument of + one tensor: start - end == 0: single tensor start - end == 1: + vector containing one tensor input_ranges: [(0, 0), (1, 3), (3, 4)] + + input_specs: [t0, t1, t2, t3] --> t0, [t1, t2], [t3] + */ + phi::distributed::InferSpmdContext ctx; + paddle::small_vector + ins; + for (auto &range : input_ranges) { + if (range.second - range.first == 0) { + auto &in = input_specs.at(range.first); + 
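// (input_ranges describes how the flat input_specs list maps onto the op's
//  arguments: a range with start == end denotes a single-tensor argument
//  taken at input_specs[start], while a non-empty range [start, end) is
//  packed into one vector-of-tensors argument via EmplaceBackInputs -- so
//  [(0, 0), (1, 3), (3, 4)] over [t0, t1, t2, t3] yields t0, [t1, t2], [t3].)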
ctx.EmplaceBackInput(phi::distributed::DistMetaTensor( + phi::make_ddim(in.shape()), in.dist_attr())); + } else { + int start = range.first; + int end = range.second; + ins.reserve(end - start); + for (int i = start; i < end; ++i) { + auto &in = input_specs.at(i); + ins.emplace_back(phi::distributed::DistMetaTensor( + phi::make_ddim(in.shape()), in.dist_attr())); + } + ctx.EmplaceBackInputs(ins); + ins.clear(); + } + } + for (auto &attr : attrs) { + ctx.EmplaceBackAttr(attr); + } + return self.InferForward(ctx); + }) .def("infer_backward", [](const phi::distributed::SpmdRule &self, const std::vector &input_specs, @@ -399,6 +438,44 @@ void BindAutoParallel(py::module *m) { ctx.EmplaceBackAttr(attr); } return self.InferBackward(ctx); + }) + .def("infer_backward", // for op that have vector argument + [](const phi::distributed::SpmdRule &self, + const std::vector> &input_ranges, + const std::vector &input_specs, + const std::vector &attrs) { + /* + to distingish between single tensor argument and vector argument of + one tensor: start - end == 0: single tensor start - end == 1: + vector containing one tensor input_ranges: [(0, 0), (1, 3), (3, 4)] + + input_specs: [t0, t1, t2, t3] --> t0, [t1, t2], [t3] + */ + phi::distributed::InferSpmdContext ctx; + paddle::small_vector + ins; + for (auto &range : input_ranges) { + if (range.second - range.first == 0) { + auto &in = input_specs.at(range.first); + ctx.EmplaceBackInput(phi::distributed::DistMetaTensor( + phi::make_ddim(in.shape()), in.dist_attr())); + } else { + int start = range.first; + int end = range.second; + ins.reserve(end - start); + for (int i = start; i < end; ++i) { + auto &in = input_specs.at(i); + ins.emplace_back(phi::distributed::DistMetaTensor( + phi::make_ddim(in.shape()), in.dist_attr())); + } + ctx.EmplaceBackInputs(ins); + ins.clear(); + } + } + for (auto &attr : attrs) { + ctx.EmplaceBackAttr(attr); + } + return self.InferBackward(ctx); }); py::class_(*m, "DistTensorSpec") diff --git a/paddle/fluid/pybind/cuda_streams_py.cc b/paddle/fluid/pybind/cuda_streams_py.cc index 2b8969e1b8181..2a6c639735a2b 100644 --- a/paddle/fluid/pybind/cuda_streams_py.cc +++ b/paddle/fluid/pybind/cuda_streams_py.cc @@ -98,23 +98,22 @@ void BindCudaStream(py::module *m_ptr) { The handle of the CUDA stream. Parameters: - device(paddle.CUDAPlace()|int|None, optional): The device which wanted to allocate the stream. - If device is None or negative integer, device will be the current device. - If device is positive integer, it must less than the device count. Default: None. - - priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal). - If priority is None, the priority is 2(normal). Default: None. + device(paddle.CUDAPlace()|int|None, optional): The device which wanted to allocate the stream. + If device is None or negative integer, device will be the current device. + If device is positive integer, it must less than the device count. Default: None. + priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal). + If priority is None, the priority is 2(normal). Default: None. Examples: - .. code-block:: python + .. 
code-block:: python - # required: gpu - import paddle - s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1) - s2 = paddle.device.cuda.Stream(0, 1) - s3 = paddle.device.cuda.Stream() + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1) + >>> s2 = paddle.device.cuda.Stream(0, 1) + >>> s3 = paddle.device.cuda.Stream() - )DOC") + )DOC") #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) .def( "wait_event", @@ -122,21 +121,20 @@ void BindCudaStream(py::module *m_ptr) { self.WaitEvent(event.GetRawCudaEvent()); }, R"DOC( - Makes all future work submitted to stream wait for all work captured in event. - - Parameters: - event(CUDAEvent): The event to wait on. + Makes all future work submitted to stream wait for all work captured in event. - Examples: - .. code-block:: python + Parameters: + event(CUDAEvent): The event to wait on. - # required: gpu - import paddle - s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1) - event = paddle.device.cuda.Event() - s.wait_event(event) + Examples: + .. code-block:: python - )DOC") + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1) + >>> event = paddle.device.cuda.Event() + >>> s.wait_event(event) + )DOC") .def( "wait_stream", [](phi::CUDAStream &self, phi::CUDAStream &stream) { @@ -145,53 +143,53 @@ void BindCudaStream(py::module *m_ptr) { self.WaitEvent(event.GetRawCudaEvent()); }, R"DOC( - Synchronizes with the given stream. + Synchronizes with the given stream. - Parameters: - stream(CUDAStream): The stream to synchronize with. + Parameters: + stream(CUDAStream): The stream to synchronize with. - Examples: - .. code-block:: python + Examples: + .. code-block:: python - # required: gpu - import paddle - s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1) - s2 = paddle.device.cuda.Stream(0, 1) - s1.wait_stream(s2) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1) + >>> s2 = paddle.device.cuda.Stream(0, 1) + >>> s1.wait_stream(s2) - )DOC") + )DOC") .def( "query", [](phi::CUDAStream &self) { return self.Query(); }, R"DOC( - Return the status whether if all operations in stream have completed. + Return the status whether if all operations in stream have completed. - Returns: A boolean value. + Returns: A boolean value. - Examples: - .. code-block:: python + Examples: + .. code-block:: python - # required: gpu - import paddle - s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1) - is_done = s.query() + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1) + >>> is_done = s.query() - )DOC") + )DOC") .def( "synchronize", [](phi::CUDAStream &self) { self.Synchronize(); }, R"DOC( - Waits for stream tasks to complete. + Waits for stream tasks to complete. - Examples: - .. code-block:: python + Examples: + .. code-block:: python - # required: gpu - import paddle - s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1) - s.synchronize() + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1) + >>> s.synchronize() - )DOC") + )DOC") .def( "record_event", [](phi::CUDAStream &self, paddle::platform::CudaEvent *event) { @@ -202,24 +200,24 @@ void BindCudaStream(py::module *m_ptr) { return event; }, R"DOC( - Record a CUDA event in the stream. + Record a CUDA event in the stream. - Parameters: - event(CUDAEvent, optional): The event to be record. 
.def_property_readonly( "cuda_stream", @@ -228,21 +226,21 @@ void BindCudaStream(py::module *m_ptr) { return reinterpret_cast(self.raw_stream()); }, R"DOC( - retrun the raw cuda stream of type cudaStream_t as type int. + Return the raw cuda stream of type cudaStream_t as type int. - Examples: - .. code-block:: python + Examples: + .. code-block:: python - # required: gpu - import paddle - import ctypes - cuda_stream = paddle.device.cuda.current_stream().cuda_stream - print(cuda_stream) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> import ctypes + >>> cuda_stream = paddle.device.cuda.current_stream().cuda_stream + >>> print(cuda_stream) - ptr = ctypes.c_void_p(cuda_stream) # convert back to void* - print(ptr) + >>> ptr = ctypes.c_void_p(cuda_stream) # convert back to void* + >>> print(ptr) - )DOC") + )DOC") .def_property_readonly("place", [](phi::CUDAStream &self) { return platform::CUDAPlace(self.place()); @@ -322,18 +320,18 @@ void BindCudaStream(py::module *m_ptr) { The handle of the CUDA event. Parameters: - enable_timing(bool, optional): Whether the event will measure time. Default: False. - blocking(bool, optional): Whether the wait() func will be blocking. Default: False; - interprocess(bool, optional): Whether the event can be shared between processes. Default: False. + enable_timing(bool, optional): Whether the event will measure time. Default: False. + blocking(bool, optional): Whether the wait() function will be blocking. Default: False. + interprocess(bool, optional): Whether the event can be shared between processes. Default: False. Examples: - .. code-block:: python + .. code-block:: python - # required: gpu - import paddle - event = paddle.device.cuda.Event() + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> event = paddle.device.cuda.Event() - )DOC") + )DOC")
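A short construction sketch for the flags documented above; the keyword names follow the parameter list, which is an assumption about the binding's argument names:

.. code-block:: python

    >>> # doctest: +REQUIRES(env:GPU)
    >>> import paddle
    >>> # an event usable for timing; wait() will not busy-block, no IPC sharing
    >>> e = paddle.device.cuda.Event(enable_timing=True, blocking=False, interprocess=False)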
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) .def( "record", @@ -347,17 +345,18 @@ Records the event in the given stream. Parameters: - stream(CUDAStream, optional): The handle of CUDA stream. If None, the stream is the current stream. Default: None. + stream(CUDAStream, optional): The handle of CUDA stream. If None, the stream is the current stream. Default: None. Examples: - .. code-block:: python + .. code-block:: python - # required: gpu - import paddle - event = paddle.device.cuda.Event() - event.record() + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> event = paddle.device.cuda.Event() + >>> event.record() - )DOC", + )DOC", py::arg("stream") = nullptr) .def( "query", [](paddle::platform::CudaEvent &self) { @@ -368,14 +367,15 @@ Returns: A boolean which indicates all work currently captured by the event has been completed. Examples: - .. code-block:: python - # required: gpu - import paddle - event = paddle.device.cuda.Event() - is_done = event.query() + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> event = paddle.device.cuda.Event() + >>> is_done = event.query() - )DOC") + )DOC") .def( "synchronize", [](paddle::platform::CudaEvent &self) { self.Synchronize(); }, @@ -383,14 +383,15 @@ Waits for an event to complete. Examples: - .. code-block:: python + .. code-block:: python - # required: gpu - import paddle - event = paddle.device.cuda.Event() - event.synchronize() + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> event = paddle.device.cuda.Event() + >>> event.synchronize() - )DOC") + )DOC") #endif .def( "__init__", diff --git a/paddle/fluid/pybind/custom_device_py.cc b/paddle/fluid/pybind/custom_device_py.cc index 0f0caa7fcdd0f..15415a86db422 100644 --- a/paddle/fluid/pybind/custom_device_py.cc +++ b/paddle/fluid/pybind/custom_device_py.cc @@ -110,29 +110,26 @@ void BindCustomDevicePy(py::module *m_ptr) { The handle of the custom device stream. Parameters: - device(paddle.CustomPlace()|str): The device which wanted to allocate the stream. - - device_id(int, optional): The id of the device which wanted to allocate the stream. - If device is None or negative integer, device will be the current device. - If device is positive integer, it must less than the device count. Default: None. - - priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal). - If priority is None, the priority is 2(normal). Default: None. - - blocking(int|None, optional): Whether the stream is executed synchronously. Default: False. + device(paddle.CustomPlace()|str): The device where the stream is allocated. + device_id(int, optional): The id of the device where the stream is allocated. + If device is None or a negative integer, the current device is used. + If device is a positive integer, it must be less than the device count. Default: None. + priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal). + If priority is None, the priority is 2(normal). Default: None. + blocking(int|None, optional): Whether the stream is executed synchronously. Default: False. Examples: - .. code-block:: python + .. code-block:: python - # required: custom_device - import paddle - s3 = paddle.device.custom.Stream('custom_cpu') - s2 = paddle.device.custom.Stream('custom_cpu', 0) - s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu')) - s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu'), 1) - s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu'), 1, True) + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> s3 = paddle.device.custom.Stream('custom_cpu') + >>> s2 = paddle.device.custom.Stream('custom_cpu', 0) + >>> s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu')) + >>> s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu'), 1) + >>> s1 = paddle.device.custom.Stream(paddle.CustomPlace('custom_cpu'), 1, True) - )DOC") + )DOC") .def( "__init__", [](phi::stream::Stream &self, @@ -196,22 +193,22 @@ void BindCustomDevicePy(py::module *m_ptr) { #endif }, R"DOC( - Makes all future work submitted to stream wait for all work captured in event. + Makes all future work submitted to stream wait for all work captured in event.
- Parameters: - event(CustomDeviceEvent): The event to wait on. + Parameters: + event(CustomDeviceEvent): The event to wait on. - Examples: - .. code-block:: python + Examples: + .. code-block:: python - # required: custom_device - import paddle - place = paddle.CustomPlace('custom_cpu', 0) - s = paddle.device.custom.Stream(place) - event = paddle.device.custom.Event(place) - s.wait_event(event) + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> place = paddle.CustomPlace('custom_cpu', 0) + >>> s = paddle.device.custom.Stream(place) + >>> event = paddle.device.custom.Event(place) + >>> s.wait_event(event) - )DOC") + )DOC") .def( "wait_stream", [](const phi::stream::Stream &self, phi::stream::Stream *other) { @@ -227,22 +224,22 @@ void BindCustomDevicePy(py::module *m_ptr) { #endif }, R"DOC( - Synchronizes with the given stream. + Synchronizes with the given stream. - Parameters: - stream(CUDAStream): The stream to synchronize with. + Parameters: + stream(CustomDeviceStream): The stream to synchronize with. - Examples: - .. code-block:: python + Examples: + .. code-block:: python - # required: custom_device - import paddle - place = paddle.CustomPlace('custom_cpu', 0) - s1 = paddle.device.custom.Stream(place) - s2 = paddle.device.custom.Stream(place) - s1.wait_stream(s2) + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> place = paddle.CustomPlace('custom_cpu', 0) + >>> s1 = paddle.device.custom.Stream(place) + >>> s2 = paddle.device.custom.Stream(place) + >>> s1.wait_stream(s2) - )DOC") + )DOC") .def( "query", [](const phi::stream::Stream &self) { @@ -255,20 +252,21 @@ void BindCustomDevicePy(py::module *m_ptr) { #endif }, R"DOC( - Return the status whether if all operations in stream have completed. + Return whether all operations in the stream have completed. - Returns: A boolean value. + Returns: + A boolean value. - Examples: - .. code-block:: python + Examples: + .. code-block:: python - # required: custom_device - import paddle - place = paddle.CustomPlace('custom_cpu', 0) - s = paddle.device.custom.Stream(place) - is_done = s.query() + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> place = paddle.CustomPlace('custom_cpu', 0) + >>> s = paddle.device.custom.Stream(place) + >>> is_done = s.query() - )DOC") + )DOC") .def( "synchronize", [](const phi::stream::Stream &self) { @@ -281,18 +279,18 @@ void BindCustomDevicePy(py::module *m_ptr) { #endif }, R"DOC( - Waits for stream tasks to complete. + Waits for stream tasks to complete. - Examples: - .. code-block:: python + Examples: + .. code-block:: python - # required: custom_device - import paddle - place = paddle.CustomPlace('custom_cpu', 0) - s = paddle.device.custom.Stream(place) - s.synchronize() + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> place = paddle.CustomPlace('custom_cpu', 0) + >>> s = paddle.device.custom.Stream(place) + >>> s.synchronize() - )DOC") + )DOC") .def( "record_event", [](const phi::stream::Stream &self, phi::event::Event *event) { @@ -310,25 +308,25 @@ void BindCustomDevicePy(py::module *m_ptr) { #endif }, R"DOC( - Record an event in the stream. + Record an event in the stream. - Parameters: - event(CustomDeviceEvent, optional): The event to be record. If event is None, a new event is created. - Default: None. + Parameters: + event(CustomDeviceEvent, optional): The event to be recorded. If event is None, a new event is created. + Default: None. - Returns: - The record event. + Returns: + The recorded event. - Examples: - ..
code-block:: python + Examples: + .. code-block:: python - # required: custom_device - import paddle - place = paddle.CustomPlace('custom_cpu', 0) - s = paddle.device.custom.Stream(place) - event = s.record_event() + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> place = paddle.CustomPlace('custom_cpu', 0) + >>> s = paddle.device.custom.Stream(place) + >>> event = s.record_event() - )DOC", + )DOC", py::arg("event") = nullptr) .def_property_readonly( "raw_stream", @@ -343,21 +341,21 @@ void BindCustomDevicePy(py::module *m_ptr) { #endif }, R"DOC( - return the raw stream of type CustomDeviceStream as type int. + Return the raw stream of type CustomDeviceStream as type int. - Examples: - .. code-block:: python + Examples: + .. code-block:: python - # required: custom_device - import paddle - import ctypes - stream = paddle.device.custom.current_stream().raw_stream - print(stream) + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> import ctypes + >>> stream = paddle.device.custom.current_stream().raw_stream + >>> print(stream) - ptr = ctypes.c_void_p(stream) # convert back to void* - print(ptr) + >>> ptr = ctypes.c_void_p(stream) # convert back to void* + >>> print(ptr) - )DOC") + )DOC") .def_property_readonly("place", [](const phi::stream::Stream &self) { #ifdef PADDLE_WITH_CUSTOM_DEVICE return reinterpret_cast(self.GetPlace()); @@ -373,27 +371,23 @@ void BindCustomDevicePy(py::module *m_ptr) { The handle of the custom device event. Parameters: - device(paddle.CustomPlace()|str): The device which wanted to allocate the stream. - - device_id(int, optional): The id of the device which wanted to allocate the stream. - If device is None or negative integer, device will be the current device. - If device is positive integer, it must less than the device count. Default: None. - - enable_timing(bool, optional): Whether the event will measure time. Default: False. - - blocking(bool, optional): Whether the wait() func will be blocking. Default: False; - - interprocess(bool, optional): Whether the event can be shared between processes. Default: False. + device(paddle.CustomPlace()|str): The device where the stream is allocated. + device_id(int, optional): The id of the device where the stream is allocated. + If device is None or a negative integer, the current device is used. + If device is a positive integer, it must be less than the device count. Default: None. + enable_timing(bool, optional): Whether the event will measure time. Default: False. + blocking(bool, optional): Whether the wait() function will be blocking. Default: False. + interprocess(bool, optional): Whether the event can be shared between processes. Default: False. Examples: - .. code-block:: python + .. code-block:: python - # required: custom_device - import paddle - place = paddle.CustomPlace('custom_cpu', 0) - event = paddle.device.custom.Event(place) + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> place = paddle.CustomPlace('custom_cpu', 0) + >>> event = paddle.device.custom.Event(place) - )DOC") + )DOC")
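An end-to-end sketch tying the custom-device stream and event bindings together; as in the surrounding examples, it assumes a device plugin registered as 'custom_cpu':

.. code-block:: python

    >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE)
    >>> import paddle
    >>> place = paddle.CustomPlace('custom_cpu', 0)
    >>> s1 = paddle.device.custom.Stream(place)
    >>> s2 = paddle.device.custom.Stream(place)
    >>> e = s1.record_event()  # capture the current tail of s1
    >>> s2.wait_event(e)       # order s2 behind the captured work
    >>> s2.synchronize()
    >>> assert s1.query() and s2.query()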
.def( "__init__", [](phi::event::Event &self, @@ -483,18 +477,18 @@ void BindCustomDevicePy(py::module *m_ptr) { Records the event in the given stream. Parameters: - stream(CustomDeviceStream, optional): The handle of custom device stream. If None, the stream is the current stream. Default: None. + stream(CustomDeviceStream, optional): The handle of custom device stream. If None, the stream is the current stream. Default: None. Examples: - .. code-block:: python + .. code-block:: python - # required: custom_device - import paddle - place = paddle.CustomPlace('custom_cpu', 0) - event = paddle.device.custom.Event(place) - event.record() + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> place = paddle.CustomPlace('custom_cpu', 0) + >>> event = paddle.device.custom.Event(place) + >>> event.record() - )DOC") + )DOC") .def( "query", [](const phi::event::Event &self) { @@ -509,18 +503,19 @@ R"DOC( Queries the event's status. - Returns: A boolean which indicates all work currently captured by the event has been completed. + Returns: + A boolean which indicates all work currently captured by the event has been completed. Examples: - .. code-block:: python + .. code-block:: python - # required: custom_device - import paddle - place = paddle.CustomPlace('custom_cpu', 0) - event = paddle.device.cuda.Event(place) - is_done = event.query() + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> place = paddle.CustomPlace('custom_cpu', 0) + >>> event = paddle.device.custom.Event(place) + >>> is_done = event.query() - )DOC") + )DOC") .def( "synchronize", [](const phi::event::Event &self) { @@ -536,15 +531,15 @@ Waits for an event to complete. Examples: - .. code-block:: python + .. code-block:: python - # required: custom_device - import paddle - place = paddle.CustomPlace('custom_cpu', 0) - event = paddle.device.custom.Event(place) - event.synchronize() + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> place = paddle.CustomPlace('custom_cpu', 0) + >>> event = paddle.device.custom.Event(place) + >>> event.synchronize() - )DOC") + )DOC") .def_property_readonly( "raw_event", [](const phi::event::Event &self) { @@ -558,23 +553,23 @@ #endif }, R"DOC( - return the raw event of type CustomDeviceEvent as type int. + Return the raw event of type CustomDeviceEvent as type int. - Examples: - .. code-block:: python + Examples: + ..
code-block:: python - # required: custom_device - import paddle - import ctypes - place = paddle.CustomPlace('custom_cpu', 0) - event = paddle.device.custom.Event(place) - raw_event = event.raw_event - print(raw_event) + >>> # doctest: +REQUIRES(env:CUSTOM_DEVICE) + >>> import paddle + >>> import ctypes + >>> place = paddle.CustomPlace('custom_cpu', 0) + >>> event = paddle.device.custom.Event(place) + >>> raw_event = event.raw_event + >>> print(raw_event) - ptr = ctypes.c_void_p(raw_event) # convert back to void* - print(ptr) + >>> ptr = ctypes.c_void_p(raw_event) # convert back to void* + >>> print(ptr) - )DOC") + )DOC") .def_property_readonly("place", [](const phi::event::Event &self) { #ifdef PADDLE_WITH_CUSTOM_DEVICE return reinterpret_cast(self.GetPlace()); diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index d03a20537eee6..e63790a65dfc8 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -64,6 +64,7 @@ typedef SSIZE_T ssize_t; #include "paddle/phi/api/include/operants_manager.h" #include "paddle/phi/api/include/tensor_operants.h" +#include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/core/flags.h" PHI_DECLARE_string(tensor_operants_mode); @@ -549,12 +550,34 @@ static PyObject* eager_api_run_custom_op(PyObject* self, continue; } if (paddle::framework::detail::IsDuplicableVar(input)) { - ctx.EmplaceBackInputs(std::move(CastPyArg2VectorOfTensor(obj, i + 1))); + std::vector tensors = + std::move(CastPyArg2VectorOfTensor(obj, i + 1)); + for (auto& tensor : tensors) { + if (tensor.initialized() && tensor.is_dense_tensor() && + !std::dynamic_pointer_cast(tensor.impl()) + ->meta() + .is_contiguous()) { + tensor.set_impl(std::make_shared( + std::move(paddle::experimental::Trans2Contiguous( + *(std::dynamic_pointer_cast( + tensor.impl())))))); + } + } + ctx.EmplaceBackInputs(std::move(tensors)); VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. Add vector size = " << ctx.InputRangeAt(i).second - ctx.InputRangeAt(i).first; } else { - ctx.EmplaceBackInput(std::move(CastPyArg2Tensor(obj, i + 1))); + paddle::Tensor tensor = std::move(CastPyArg2Tensor(obj, i + 1)); + if (tensor.initialized() && tensor.is_dense_tensor() && + !std::dynamic_pointer_cast(tensor.impl()) + ->meta() + .is_contiguous()) { + tensor.set_impl(std::make_shared( + std::move(paddle::experimental::Trans2Contiguous(*( + std::dynamic_pointer_cast(tensor.impl())))))); + } + ctx.EmplaceBackInput(std::move(tensor)); VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. 
Add Tensor for general case."; } diff --git a/paddle/fluid/pybind/eager_legacy_custom_python_api.h b/paddle/fluid/pybind/eager_legacy_custom_python_api.h index 1deb20fbf9b88..1c40ce4275c42 100644 --- a/paddle/fluid/pybind/eager_legacy_custom_python_api.h +++ b/paddle/fluid/pybind/eager_legacy_custom_python_api.h @@ -21,7 +21,7 @@ namespace paddle { namespace pybind { -static PyObject *eager_api_run_program(PyObject *self, +static PyObject *eager_api_run_program(PyObject *self, // TOREMOVE PyObject *args, PyObject *kwargs) { PyThreadState *tstate = nullptr; @@ -61,11 +61,58 @@ static PyObject *eager_api_run_program(PyObject *self, } } +static PyObject *newir_eager_api_run_program(PyObject *self, + PyObject *args, + PyObject *kwargs) { + PyThreadState *tstate = nullptr; + try { + auto X = GetTensorListFromArgs("run_program", "X", args, 0, true); + auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true); + auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, true); + auto OutScope = + GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false); + auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true); + framework::AttributeMap attrs; + // TODO(zengjinle): support CUDA Graph on eager mode + VLOG(1) << "Start NewIR ConstructAttrMapFromPyArgs"; + + ConstructAttrMapForRunProgram( + "run_program", args, 6, PyTuple_GET_SIZE(args), attrs); + + VLOG(1) << "Finish NewIR ConstructAttrMapFromPyArgs"; + tstate = PyEval_SaveThread(); + newir_run_program_ad_func(X, Params, Out, OutScope, DOut, attrs); + PyEval_RestoreThread(tstate); + tstate = nullptr; + Py_RETURN_NONE; + } catch (paddle::platform::EnforceNotMet &exception) { + if (tstate) { + PyEval_RestoreThread(tstate); + } + std::ostringstream sout; + sout << exception.what(); + sout << " [operator < run_program > error]"; + exception.set_error_str(sout.str()); + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } catch (...) 
{ + if (tstate) { + PyEval_RestoreThread(tstate); + } + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + static PyMethodDef CustomEagerMethods[] = { {"run_program", (PyCFunction)(void (*)(void))eager_api_run_program, METH_VARARGS | METH_KEYWORDS, "C++ interface function for run_program in dygraph."}, + {"newir_run_program", + (PyCFunction)(void (*)(void))newir_eager_api_run_program, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for run_program in dygraph."}, {nullptr, nullptr, 0, nullptr}}; } // namespace pybind diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 4c1fb0b431070..4046ef525bfd6 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1665,14 +1665,17 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self, // use inplace set_value_ operator if (value_tensor.initialized() && (self->tensor.dtype() != value_tensor.dtype())) { - paddle::small_vector, - egr::kSlotSmallVectorSize> - tmps = {{self->tensor}, {value_tensor}}; - auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps); - self->tensor = egr::EagerAmpAutoCast( - self->tensor.name(), self->tensor, amp_dtype, "set_value"); - value_tensor = egr::EagerAmpAutoCast( - value_tensor.name(), value_tensor, amp_dtype, "set_value"); + if (egr::Controller::Instance().GetAMPLevel() != + paddle::imperative::AmpLevel::O0) { + paddle::small_vector, + egr::kSlotSmallVectorSize> + tmps = {{self->tensor}, {value_tensor}}; + auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps); + self->tensor = egr::EagerAmpAutoCast( + self->tensor.name(), self->tensor, amp_dtype, "set_value"); + value_tensor = egr::EagerAmpAutoCast( + value_tensor.name(), value_tensor, amp_dtype, "set_value"); + } if (self->tensor.dtype() != value_tensor.dtype()) { value_tensor = cast_ad_func(value_tensor, self->tensor.dtype()); } diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 95d86f544c4bf..84418058aa9f5 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -11,7 +11,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/eager_utils.h" #include -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/value.h" // Avoid a problem with copysign defined in pyconfig.h on Windows. #ifdef copysign #undef copysign @@ -138,6 +138,25 @@ bool PyObject_CheckIROpResult(PyObject* obj) { return PyObject_TypeCheck(obj, g_ir_opresult_pytype); } +bool PyObject_CheckIRVectorOfOpResult(PyObject* obj) { + if (PyList_Check(obj)) { + Py_ssize_t len = PyList_Size(obj); + PyObject* item = nullptr; + // if obj is [], parse it as std::vector + if (len == 0) { + return false; + } + for (Py_ssize_t i = 0; i < len; i++) { + item = PyList_GetItem(obj, i); + if (!PyObject_CheckIROpResult(item)) { + return false; + } + } + return true; + } else { + return false; + } +} bool CastPyArg2AttrBoolean(PyObject* obj, ssize_t arg_pos) { if (obj == Py_None) { return false; // To be compatible with QA integration testing. 
Some @@ -888,13 +907,13 @@ PyObject* ToPyObject(const phi::DenseTensor* value) { return obj.ptr(); } -PyObject* ToPyObject(const ir::OpResult& value) { +PyObject* ToPyObject(const pir::OpResult& value) { auto obj = ::pybind11::cast(value); obj.inc_ref(); return obj.ptr(); } -PyObject* ToPyObject(const std::vector& value) { +PyObject* ToPyObject(const std::vector& value) { PyObject* result = PyList_New((Py_ssize_t)value.size()); for (size_t i = 0; i < value.size(); i++) { @@ -1485,13 +1504,13 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj, } } -ir::OpResult CastPyArg2OpResult(PyObject* obj, - const std::string& op_type, - size_t arg_pos) { +pir::OpResult CastPyArg2OpResult(PyObject* obj, + const std::string& op_type, + size_t arg_pos) { if (PyObject_TypeCheck(obj, g_ir_opresult_pytype)) { - return ::pybind11::handle(obj).cast(); + return ::pybind11::handle(obj).cast(); } else if (obj == nullptr || obj == Py_None) { - return ir::OpResult(); + return pir::OpResult(); } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -1502,17 +1521,17 @@ ir::OpResult CastPyArg2OpResult(PyObject* obj, } } -std::vector CastPyArg2VectorOfOpResult(PyObject* obj, - const std::string& op_type, - size_t arg_pos) { - std::vector result_list; +std::vector CastPyArg2VectorOfOpResult( + PyObject* obj, const std::string& op_type, size_t arg_pos) { + std::vector result_list; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_TypeCheck(item, g_ir_opresult_pytype)) { - result_list.emplace_back(::pybind11::handle(item).cast()); + result_list.emplace_back( + ::pybind11::handle(item).cast()); } else if (item == Py_None) { continue; } else { @@ -1531,7 +1550,8 @@ std::vector CastPyArg2VectorOfOpResult(PyObject* obj, for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_TypeCheck(item, g_ir_opresult_pytype)) { - result_list.emplace_back(::pybind11::handle(item).cast()); + result_list.emplace_back( + ::pybind11::handle(item).cast()); } else if (item == Py_None) { continue; } else { @@ -1545,7 +1565,7 @@ std::vector CastPyArg2VectorOfOpResult(PyObject* obj, } } } else if (PyObject_TypeCheck(obj, g_ir_opresult_pytype)) { - return {::pybind11::handle(obj).cast()}; + return {::pybind11::handle(obj).cast()}; } else if (obj == Py_None) { return {}; } else { @@ -1697,7 +1717,6 @@ paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } - // Fake a IntArray return paddle::experimental::IntArray({1}); } diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index ad7ec2d42c437..ba2368c9b6bb2 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -29,7 +29,6 @@ typedef SSIZE_T ssize_t; #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/jit/function.h" #include "paddle/fluid/platform/place.h" -#include "paddle/ir/core/value.h" #include "paddle/phi/common/backend.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/int_array.h" @@ -38,6 +37,7 @@ typedef SSIZE_T ssize_t; #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/selected_rows.h" +#include "paddle/pir/core/op_result.h" #include "paddle/utils/pybind.h" #include "pybind11/pybind11.h" #include 
"pybind11/stl.h" @@ -57,6 +57,7 @@ bool PyObject_CheckLongOrConvertToLong(PyObject** obj); bool PyObject_CheckFloatOrConvertToFloat(PyObject** obj); bool PyObject_CheckStr(PyObject* obj); bool PyObject_CheckIROpResult(PyObject* obj); +bool PyObject_CheckIRVectorOfOpResult(PyObject* obj); bool CastPyArg2AttrBoolean(PyObject* obj, ssize_t arg_pos); int CastPyArg2AttrInt(PyObject* obj, ssize_t arg_pos); int64_t CastPyArg2AttrLong(PyObject* obj, ssize_t arg_pos); @@ -75,12 +76,11 @@ std::vector CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos); std::vector CastPyArg2VectorOfInt64(PyObject* obj, size_t arg_pos); std::vector CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos); std::vector CastPyArg2VectorOfFloat(PyObject* obj, size_t arg_pos); -ir::OpResult CastPyArg2OpResult(PyObject* obj, - const std::string& op_type, - size_t arg_pos); -std::vector CastPyArg2VectorOfOpResult(PyObject* obj, - const std::string& op_type, - size_t arg_pos); +pir::OpResult CastPyArg2OpResult(PyObject* obj, + const std::string& op_type, + size_t arg_pos); +std::vector CastPyArg2VectorOfOpResult( + PyObject* obj, const std::string& op_type, size_t arg_pos); std::vector> CastPyArg2VectorOfVectorOfSize_t( PyObject* obj, size_t arg_pos); framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, @@ -131,8 +131,8 @@ PyObject* ToPyObject(const paddle::framework::Vocab& value); PyObject* ToPyObject(std::shared_ptr grad_node); -PyObject* ToPyObject(const ir::OpResult& value); -PyObject* ToPyObject(const std::vector& value); +PyObject* ToPyObject(const pir::OpResult& value); +PyObject* ToPyObject(const std::vector& value); class PyTensorHook : public egr::TensorHook { public: diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 1690d738a2c60..66f24b6f03fc3 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -245,6 +245,8 @@ paddle_infer::PlaceType ToPaddleInferPlace( return paddle_infer::PlaceType::kGPU; } else if (allocation_type == phi::AllocationType::XPU) { return paddle_infer::PlaceType::kXPU; + } else if (allocation_type == phi::AllocationType::CUSTOM) { + return paddle_infer::PlaceType::kCUSTOM; } else { return paddle_infer::PlaceType::kCPU; } @@ -975,19 +977,19 @@ void BindAnalysisConfig(py::module *m) { .def("disable_mkldnn_fc_passes", &AnalysisConfig::DisableMkldnnFcPasses, R"DOC( - Disable Mkldnn FC - Args: + Disable Mkldnn FC + Returns: None. - Returns: - None. - Examples: - .. code-block:: python - from paddle.inference import Config - - config = Config("") - config.enable_mkldnn() - config.disable_mkldnn_fc_passes() - )DOC") + + Examples: + .. 
code-block:: python + + >>> from paddle.inference import Config + + >>> config = Config("") + >>> config.enable_mkldnn() + >>> config.disable_mkldnn_fc_passes() + )DOC") #endif .def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp) .def("set_model_buffer", &AnalysisConfig::SetModelBuffer) diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 4dc36fe785ecc..465a8719b3c7f 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -25,38 +25,40 @@ #include "paddle/fluid/pybind/pybind_variant_caster.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/op_yaml_info.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" -#include "paddle/fluid/ir/transforms/inplace_pass.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" #include "paddle/fluid/ir_adaptor/translator/utils.h" -#include "paddle/ir/core/block.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/type.h" -#include "paddle/ir/core/value.h" -#include "paddle/ir/pass/pass.h" -#include "paddle/ir/pass/pass_manager.h" -#include "paddle/ir/pass/pass_registry.h" -#include "paddle/ir/transforms/dead_code_elimination_pass.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/transforms/inplace_pass.h" #include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/value.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/transforms/dead_code_elimination_pass.h" #include "pybind11/stl.h" namespace py = pybind11; -using ir::Block; -using ir::Operation; -using ir::OpOperand; -using ir::OpResult; -using ir::Pass; -using ir::PassManager; -using ir::Program; -using ir::Type; -using ir::Value; using paddle::dialect::APIBuilder; using paddle::dialect::DenseTensorType; +using pir::Block; +using pir::Operation; +using pir::OpOperand; +using pir::OpResult; +using pir::Pass; +using pir::PassManager; +using pir::Program; +using pir::Type; +using pir::Value; using pybind11::return_value_policy; USE_PASS(dead_code_elimination); @@ -69,6 +71,26 @@ PyTypeObject *g_ir_opresult_pytype = nullptr; void BindOpsAPI(pybind11::module *module); +inline int64_t GetProgramInt64Attr(const std::shared_ptr &program, + const std::string &attr_name, + int64_t default_value = 0) { + auto op = program->module_op(); + if (op->HasAttribute(attr_name)) { + auto val = op->attribute(attr_name).dyn_cast().data(); + return val; + } else { + return default_value; + } +} + +inline void SetProgramInt64Attr(std::shared_ptr program, + const std::string &attr_name, + int64_t value) { + auto op = program->module_op(); + op->set_attribute( + attr_name, 
pir::Int64Attribute::get(pir::IrContext::Instance(), value)); +} + void BindProgram(py::module *m) { py::class_> program(*m, "Program", R"DOC( Create Python Program. Program is an abstraction of model structure, divided into @@ -111,27 +133,42 @@ void BindProgram(py::module *m) { print("start up program is: {}".format(startup_program)) )DOC"); program - .def( - "__init__", - [](Program &self) { new (&self) Program(ir::IrContext::Instance()); }) + .def("__init__", + [](Program &self) { + new (&self) Program(pir::IrContext::Instance()); + }) .def("__str__", [](const std::shared_ptr &self) { std::ostringstream print_stream; self->Print(print_stream); return print_stream.str(); }) + .def("__repr__", + [](const std::shared_ptr &self) { + std::ostringstream print_stream; + self->Print(print_stream); + return print_stream.str(); + }) .def("parameters_num", [](const std::shared_ptr &self) { return self->parameters_num(); }) .def( - "block", + "global_block", [](std::shared_ptr self) { return self->block(); }, return_value_policy::reference) .def( - "block", + "global_block", [](const std::shared_ptr &self) { return self->block(); }, - return_value_policy::reference); + return_value_policy::reference) + .def_property( + "random_seed", + [](const std::shared_ptr &self) { + return GetProgramInt64Attr(self, "random_seed", 0); + }, + [](std::shared_ptr self, int64_t random_seed) { + SetProgramInt64Attr(self, "random_seed", random_seed); + }); } void BindBlock(py::module *m) { @@ -143,8 +180,10 @@ void BindBlock(py::module *m) { use `Program.block()` to get a block. )DOC"); block.def("front", &Block::front, return_value_policy::reference) - .def("get_parent_program", - [](Block &self) { return self.GetParentOp()->GetParentProgram(); }) + .def_property_readonly( + "program", + [](Block &self) { return self.GetParentOp()->GetParentProgram(); }, + return_value_policy::reference) .def_property_readonly( "ops", [](Block &self) -> py::list { @@ -169,7 +208,26 @@ void BindBlock(py::module *m) { Returns: None - )DOC"); + )DOC") + .def("all_parameters", [](Block &self) -> py::list { + py::list param_list; + for (auto iter = self.begin(); iter != self.end(); iter++) { + auto op = *iter; + if (op->HasAttribute(kAttrIsPersisable)) { + auto attrs = op->attribute(kAttrIsPersisable) + .dyn_cast() + .AsVector(); + for (uint32_t i = 0; i < attrs.size(); i++) { + bool is_persistable = + attrs[i].dyn_cast().data(); + if (is_persistable) { + param_list.append(op->result(i)); + } + } + } + } + return param_list; + }); } void BindOperation(py::module *m) { @@ -284,10 +342,10 @@ void BindValue(py::module *m) { .def("__eq__", &Value::operator==) .def("__eq__", [](Value &self, OpResult &other) { - return self.impl() == other.value_impl(); + return self.impl() == other.Value::impl(); }) .def("__hash__", - [](const Value &self) { return std::hash{}(self); }); + [](const Value &self) { return std::hash{}(self); }); } void BindOpOperand(py::module *m) { @@ -311,37 +369,36 @@ void BindOpOperand(py::module *m) { .def("owner", &OpOperand::owner, return_value_policy::reference); } -bool GetStopGradient(const OpResult &self) { +bool GetOpResultBoolAttr(const OpResult &self, const std::string &attr_name) { auto *defining_op = self.owner(); - if (defining_op->HasAttribute(kAttrStopGradients)) { - auto stop_gradients = defining_op->attribute(kAttrStopGradients) - .dyn_cast() - .AsVector(); - return stop_gradients[self.GetResultIndex()] - .dyn_cast() - .data(); + if (defining_op->HasAttribute(attr_name)) { + auto attrs = 
defining_op->attribute(attr_name) + .dyn_cast() + .AsVector(); + return attrs[self.GetResultIndex()].dyn_cast().data(); } else { return false; } } -void SetStopGradient(const OpResult &self, bool stop_gradient) { +void SetOpResultBoolAttr(const OpResult &self, + const std::string &attr_name, + bool value) { auto *defining_op = self.owner(); - std::vector stop_gradients; - if (defining_op->HasAttribute(kAttrStopGradients)) { - stop_gradients = defining_op->attribute(kAttrStopGradients) - .dyn_cast() - .AsVector(); + std::vector attrs; + if (defining_op->HasAttribute(attr_name)) { + attrs = defining_op->attribute(attr_name) + .dyn_cast() + .AsVector(); } else { - stop_gradients = std::vector( + attrs = std::vector( defining_op->num_results(), - ir::BoolAttribute::get(ir::IrContext::Instance(), false)); + pir::BoolAttribute::get(pir::IrContext::Instance(), false)); } - stop_gradients[self.GetResultIndex()] = - ir::BoolAttribute::get(ir::IrContext::Instance(), stop_gradient); + attrs[self.GetResultIndex()] = + pir::BoolAttribute::get(pir::IrContext::Instance(), value); defining_op->set_attribute( - kAttrStopGradients, - ir::ArrayAttribute::get(ir::IrContext::Instance(), stop_gradients)); + attr_name, pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs)); } void BindOpResult(py::module *m) { @@ -356,24 +413,98 @@ void BindOpResult(py::module *m) { op_result.def("__eq__", &OpResult::operator==) .def("__eq__", [](OpResult &self, Value &other) { - return self.value_impl() == other.impl(); + return self.Value::impl() == other.impl(); + }) + .def("__neg__", + [](OpResult &self) { + return paddle::dialect::scale(self, -1.0, 0.0, true); + }) + .def("__add__", + [](OpResult &self, OpResult &other) { + return paddle::dialect::add(self, other); + }) + .def("__sub__", + [](OpResult &self, OpResult &other) { + return paddle::dialect::subtract(self, other); + }) + .def("__mul__", + [](OpResult &self, OpResult &other) { + return paddle::dialect::multiply(self, other); + }) + .def("__truediv__", + [](OpResult &self, OpResult &other) { + return paddle::dialect::divide(self, other); + }) + .def("__lt__", + [](OpResult &self, OpResult &other) { + return paddle::dialect::less_than(self, other); + }) + .def("__le__", + [](OpResult &self, OpResult &other) { + return paddle::dialect::less_equal(self, other); + }) + .def("__gt__", + [](OpResult &self, OpResult &other) { + return paddle::dialect::greater_than(self, other); + }) + .def("__ge__", + [](OpResult &self, OpResult &other) { + return paddle::dialect::greater_equal(self, other); }) .def("__hash__", [](OpResult &self) { - return std::hash{}(self.dyn_cast()); + return std::hash{}(self.dyn_cast()); }) .def("get_defining_op", &OpResult::GetDefiningOp, return_value_policy::reference) + .def_property_readonly( + "block", + [](OpResult &self) { return self.GetDefiningOp()->GetParent(); }, + return_value_policy::reference) + .def_property_readonly( + "name", + [](OpResult &self) { + if (self.GetDefiningOp()->name() == "builtin.get_parameter") { + auto param_name = self.GetDefiningOp() + ->attributes() + .at("parameter_name") + .dyn_cast() + .AsString(); + return param_name; + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Currently, we can only get name of OpResult that is " + "persistable")); + } + }) .def("first_use", &OpResult::first_use, return_value_policy::reference) .def("has_one_use", &Value::HasOneUse) .def("use_empty", &OpResult::use_empty) .def("type", &OpResult::type) + .def("is_dense_tensor_type", + [](OpResult &self) { + if 
(self.type().isa()) { + return true; + } else { + return false; + } + }) .def_property( "stop_gradient", - [](OpResult &self) { return GetStopGradient(self); }, + [](OpResult &self) { + return GetOpResultBoolAttr(self, kAttrStopGradients); + }, [](OpResult &self, bool stop_gradient) { - SetStopGradient(self, stop_gradient); + SetOpResultBoolAttr(self, kAttrStopGradients, stop_gradient); + }) + .def_property( + "is_persistable", + [](OpResult &self) { + return GetOpResultBoolAttr(self, kAttrIsPersisable); + }, + [](OpResult &self, bool is_persistable) { + SetOpResultBoolAttr(self, kAttrIsPersisable, is_persistable); }) .def_property( "shape", @@ -417,7 +548,324 @@ void BindType(py::module *m) { }); } +Operation *BuildOpFrom( + const Operation *to_copy_op, + std::unordered_map &value_map) { // NOLINT + pir::OperationArgument to_create_argument(to_copy_op->info()); + to_create_argument.attributes = to_copy_op->attributes(); + + auto origin_results = to_copy_op->results(); + std::transform(origin_results.begin(), + origin_results.end(), + std::back_inserter(to_create_argument.output_types), + [](const pir::OpResult &r) { + // OpResult -> OpType + return r.type(); + }); + + // transform by value_map dict. + auto origin_operands = to_copy_op->operands(); + std::transform(origin_operands.begin(), + origin_operands.end(), + std::back_inserter(to_create_argument.inputs), + [&value_map](const pir::OpOperand &operand) { + // Operand -> OpResult + return value_map[operand.source()].impl(); + }); + auto *cloned_op = Operation::Create(std::move(to_create_argument)); + + // update the mapping of value_map. std::transform is a map(func, zip()). + std::vector tmp; + std::transform(origin_results.begin(), + origin_results.end(), + cloned_op->results().begin(), + std::back_inserter(tmp), // NOLINT, just a placeholder. + [&value_map](const OpResult &a, const OpResult &b) { // NOLINT + value_map[a.Value::impl()] = b.Value::impl(); + return 1; + }); + return cloned_op; +} + +std::shared_ptr ProgramClone(const Program &program) { + // Limitation of this function: + // 1. don't support Parameters. + // 2. don't support Regions in operator. 
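// A usage sketch for reviewers (hypothetical driver code, not part of this patch):
//   pir::Program src(pir::IrContext::Instance());
//   /* ... populate src ... */
//   auto cloned = ProgramClone(src);  // deep-copies every op into a fresh Program
// Each cloned op gets fresh result Values; value_map records old -> new so that
// operands of later ops are remapped onto the cloned producers.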
+ pir::IrContext *ctx = pir::IrContext::Instance(); + auto cloned_program = std::make_shared(ctx); + std::unordered_map value_map; + for (auto &op : *program.block()) { + auto *cloned_op = BuildOpFrom(op, value_map); + cloned_program->block()->push_back(cloned_op); + } + return cloned_program; +} + +std::list::const_iterator list_offset(const Block *block, + int start_idx) { + auto it = block->begin(); + while (start_idx--) ++it; + return it; +} + +template +void range_block_do(const Block *block, std::vector range, F fn) { + for (auto it = list_offset(block, range[0]); + it != list_offset(block, range[1]); + ++it) { + fn(*it); + } +} + +std::vector AnalysisMiddleVariable( + const Program &program, + const std::vector &forward_inputs, + const std::vector &forward_range, + const std::vector &backward_range) { + std::vector middle_values; + + std::unordered_set backward_inputs; + std::unordered_set x_or_param(forward_inputs.begin(), + forward_inputs.end()); + range_block_do( + program.block(), backward_range, [&backward_inputs](Operation *op) { + for (auto &t : op->operands()) { + backward_inputs.insert(t.source()); + } + }); + + range_block_do( + program.block(), + forward_range, + [&middle_values, &backward_inputs, &x_or_param](Operation *op) { + for (auto &t : op->results()) { + auto v = Value(t.Value::impl()); + if (backward_inputs.count(v) && !x_or_param.count(v)) + middle_values.push_back(v); + } + }); + return middle_values; +} + +void mapping_value(const std::vector &origin, + const std::unordered_map &value_map, + std::vector &out) { // NOLINT + std::transform(origin.begin(), + origin.end(), + std::back_inserter(out), + [&value_map](const pir::Value &v) { + if (v.impl() == nullptr) return Value(nullptr); + return value_map.at(v); + }); +} + +using SplitedProgram = std::vector>; +using SplitedAttribute = std::map>; +using SplitedResult = std::pair; + +pir::OpResult FakeOpResult() { + // create a fake OpResult to simplify `ForwardBackwardSplit`.
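// A null-impl OpResult acts as a positional placeholder: callers can pass it
// where a grad value is absent so that ForwardBackwardSplit keeps its argument
// slots aligned, and the helpers below all check `impl() == nullptr` before use.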
+ return pir::OpResult(nullptr); +} + +SplitedResult ForwardBackwardSplit( + const Program &program, + const std::vector &op_result_forward_inputs, + const std::vector &op_result_forward_outputs, + const std::vector &op_result_forward_inputs_grads, + const std::vector &op_result_forward_outputs_grads, + const std::vector &forward_range, + const std::vector &backward_range) { + // transform OpResult -> Value + VLOG(1) << "Start Prepare data structures."; + std::vector forward_inputs, forward_outputs, forward_inputs_grads, + forward_outputs_grads; + + auto op_result_to_value = [](const pir::OpResult &r) { + if (r.impl() == nullptr) return Value(nullptr); + return Value(r.Value::impl()); + }; + + std::transform(op_result_forward_inputs.begin(), + op_result_forward_inputs.end(), + std::back_inserter(forward_inputs), + op_result_to_value); + std::transform(op_result_forward_outputs.begin(), + op_result_forward_outputs.end(), + std::back_inserter(forward_outputs), + op_result_to_value); + std::transform(op_result_forward_inputs_grads.begin(), + op_result_forward_inputs_grads.end(), + std::back_inserter(forward_inputs_grads), + op_result_to_value); + std::transform(op_result_forward_outputs_grads.begin(), + op_result_forward_outputs_grads.end(), + std::back_inserter(forward_outputs_grads), + op_result_to_value); + + std::vector forward_in_out_values; + for (auto &v : std::vector *>( + {&forward_inputs, &forward_outputs})) { + forward_in_out_values.insert( + forward_in_out_values.end(), v->begin(), v->end()); + } + + std::vector fx, fp, fm, fo, bx, bp, bm, bo_g, bx_g, bp_g, bo; + pir::IrContext *ctx = pir::IrContext::Instance(); + auto forward_program = std::make_shared(ctx); + auto backward_program = std::make_shared(ctx); + auto middle_values = AnalysisMiddleVariable( + program, forward_in_out_values, forward_range, backward_range); + std::unordered_map forward_value_map; + std::unordered_map backward_value_map; + pir::Builder backward_builder = pir::Builder(ctx, backward_program->block()); + + // forward program construct. + VLOG(1) << "Before Forward Construct."; + range_block_do(program.block(), + forward_range, + [&forward_value_map, &forward_program](Operation *op) { + auto *cloned_op = BuildOpFrom(op, forward_value_map); + forward_program->block()->push_back(cloned_op); + }); + VLOG(1) << "After Forward Construct."; + + // backward program construct. + // Step1. insert data op for inputs_values and middle_values + int counter = 0; + auto create_data_fn = [&backward_builder, &backward_value_map, &counter]( + const pir::Value &v) { + if (v.impl() == nullptr) { + return; + } + auto value_type = v.type().dyn_cast(); + auto dtype = paddle::dialect::TransToPhiDataType(value_type.dtype()); + auto shape = phi::vectorize(value_type.dims()); + auto place = phi::CPUPlace(); // TODO(xiongkun): how to get default places.
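// Every value the backward program consumes (forward outputs, inputs, middles,
// and output grads) is re-declared below as a data op named "input_0",
// "input_1", ...; backward_value_map then lets the cloned backward ops read
// these feeds in place of the original forward-program values.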
+ + paddle::dialect::DataOp op = + backward_builder.Build( + std::string("input_") + std::to_string(counter), + shape, + dtype, + place); + counter += 1; + backward_value_map[v] = op->results()[0].Value::impl(); + }; + + auto create_output_fn_forward = [&ctx, + &forward_value_map, + &counter, + &forward_program](const pir::Value &v) { + if (v.impl() == nullptr) { + return; + } + auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name()); + pir::AttributeMap attribute_map = { + {"parameter_name", + pir::StrAttribute::get( + ctx, std::string("output_") + std::to_string(counter))}, + }; + pir::Operation *operation = pir::Operation::Create( + {OpResult(forward_value_map[v].impl())}, attribute_map, {}, op_info); + forward_program->block()->push_back(operation); + counter += 1; + }; + + auto create_output_fn_backward = [&ctx, + &backward_value_map, + &counter, + &backward_program](const pir::Value &v) { + if (v.impl() == nullptr) { + return; + } + auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name()); + pir::AttributeMap attribute_map = { + {"parameter_name", + pir::StrAttribute::get( + ctx, std::string("output_") + std::to_string(counter))}, + }; + pir::Operation *operation = + pir::Operation::Create({OpResult(backward_value_map.at(v).impl())}, + attribute_map, + {}, + op_info); + backward_program->block()->push_back(operation); + counter += 1; + }; + + counter = 0; + std::for_each(forward_outputs.begin(), forward_outputs.end(), create_data_fn); + std::for_each(forward_inputs.begin(), forward_inputs.end(), create_data_fn); + std::for_each(middle_values.begin(), middle_values.end(), create_data_fn); + std::for_each(forward_outputs_grads.begin(), + forward_outputs_grads.end(), + create_data_fn); + VLOG(1) << "After create pd.data for backward program."; + + counter = 0; + std::for_each( + middle_values.begin(), middle_values.end(), create_output_fn_forward); + std::for_each( + forward_outputs.begin(), forward_outputs.end(), create_output_fn_forward); + + VLOG(1) << "After call create_output_fn"; + // Step2. copy backward ops . + range_block_do(program.block(), + backward_range, + [&backward_value_map, &backward_program](Operation *op) { + auto *cloned_op = BuildOpFrom(op, backward_value_map); + backward_program->block()->push_back(cloned_op); + }); + VLOG(1) << "After call backward copy"; + counter = 0; + std::for_each(forward_inputs_grads.begin(), + forward_inputs_grads.end(), + create_output_fn_backward); + // TODO(xiongkun): add forward parameter grads. + + VLOG(1) << "forward_value_map.size() is " << forward_value_map.size(); + VLOG(1) << "backward_value_map.size() is " << backward_value_map.size(); + std::ostringstream print_stream; + print_stream << "ForwardProgram is :\n"; + forward_program->Print(print_stream); + print_stream << "BackwardProgram is:\n"; + backward_program->Print(print_stream); + VLOG(1) << "Splited Program (fwd | bwd): \n" << print_stream.str(); + + // construct all attributes we needed. 
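// Key to the attribute map assembled below (f* = forward program, b* =
// backward program): fx/bx forward inputs, fm/bm middle values, fo/bo forward
// outputs, bo_g grads of forward outputs, bx_g grads of forward inputs, and
// fp/bp/bp_g parameter slots, which stay empty until parameter support lands
// (see the TODO above).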
+ + mapping_value(middle_values, forward_value_map, fm); // write 'fm' + mapping_value(middle_values, backward_value_map, bm); // write 'bm' + mapping_value(forward_inputs, forward_value_map, fx); // write 'fx' + mapping_value(forward_inputs, backward_value_map, bx); // write 'bx' + mapping_value(forward_outputs, forward_value_map, fo); // write 'fo' + mapping_value( + forward_inputs_grads, backward_value_map, bx_g); // write 'bx_g' + mapping_value( + forward_outputs_grads, backward_value_map, bo_g); // write 'bo_g' + mapping_value(forward_outputs, backward_value_map, bo); // write 'bo' + + std::map> attr = {{"fx", fx}, + {"fp", fp}, + {"fm", fm}, + {"fo", fo}, + {"bx", bx}, + {"bp", bp}, + {"bm", bm}, + {"bo_g", bo_g}, + {"bx_g", bx_g}, + {"bp_g", bp_g}, + {"bo", bo}}; + std::vector> programs = {forward_program, + backward_program}; + return std::make_pair(programs, attr); +} + void BindUtils(pybind11::module *m) { + m->def("program_clone", ProgramClone); + m->def("program_split", ForwardBackwardSplit); + m->def("fake_op_result", FakeOpResult); m->def("set_global_program", [](Program *program) { APIBuilder::Instance().SetProgram(program); }); m->def("set_insertion_point", @@ -427,8 +875,8 @@ m->def("reset_insertion_point_to_end", []() { APIBuilder::Instance().ResetInsertionPointToEnd(); }); m->def("register_paddle_dialect", []() { - ir::IrContext::Instance() - ->GetOrRegisterDialect(); + pir::IrContext::Instance() + ->GetOrRegisterDialect(); }); m->def( "translate_to_new_ir", @@ -476,7 +924,7 @@ m->def( "check_unregistered_ops", [](const framework::ProgramDesc &legacy_program) { - ir::IrContext *ctx = ir::IrContext::Instance(); + pir::IrContext *ctx = pir::IrContext::Instance(); return paddle::translator::CheckUnregisteredOperation(ctx, legacy_program); }, @@ -516,13 +964,13 @@ void BindPassManager(pybind11::module *m) { .def( "__init__", [](PassManager &self, uint8_t opt_level) { - new (&self) PassManager(ir::IrContext::Instance(), opt_level); + new (&self) PassManager(pir::IrContext::Instance(), opt_level); }, py::arg("opt_level") = 2) .def("add_pass", [](PassManager &self, const std::string &pass_name) { self.AddPass( - std::move(ir::PassRegistry::Instance().Get(pass_name))); + std::move(pir::PassRegistry::Instance().Get(pass_name))); }) .def("passes", [](PassManager &self) { diff --git a/paddle/fluid/pybind/jit.cc b/paddle/fluid/pybind/jit.cc index a0e130f40cf64..69b32fca9cd75 100644 --- a/paddle/fluid/pybind/jit.cc +++ b/paddle/fluid/pybind/jit.cc @@ -258,12 +258,22 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, // Re-enable custom behavior eval_frame_callback_set(callback); VLOG(7) << "Start eval new frame and code."; - auto out = eval_custom_code(tstate, frame, code, throw_flag); + PyObject *out; + if (reinterpret_cast(code) != Py_None) { + out = eval_custom_code(tstate, frame, code, throw_flag); + } else { + out = eval_frame_default(tstate, frame, throw_flag); + } Py_DECREF(result); Py_DECREF(code); return out; } else { - auto out = eval_custom_code(tstate, frame, code, throw_flag); + PyObject *out; + if (reinterpret_cast(code) != Py_None) { + out = eval_custom_code(tstate, frame, code, throw_flag); + } else { + out = eval_frame_default(tstate, frame, throw_flag); + } // Re-enable custom behavior eval_frame_callback_set(callback); Py_DECREF(result); diff --git a/paddle/fluid/pybind/manual_static_op_function.h b/paddle/fluid/pybind/manual_static_op_function.h new file mode 100644 index
0000000000000..ff365e63bb652 --- /dev/null +++ b/paddle/fluid/pybind/manual_static_op_function.h @@ -0,0 +1,89 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/pir/dialect/operator/ir/manual_api.h" +#include "paddle/fluid/pybind/eager_utils.h" +#include "paddle/fluid/pybind/exception.h" +#include "paddle/fluid/pybind/op_function_common.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/core/enforce.h" + +namespace paddle { + +namespace pybind { +static PyObject *static_api_get_parameter(PyObject *self, + PyObject *args, + PyObject *kwargs) { + try { + VLOG(6) << "Add get_parameter op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + + // Parse Attributes + PyObject *name_obj = PyTuple_GET_ITEM(args, 0); + std::string name = CastPyArg2String(name_obj, "name", 0); + PyObject *dtype_obj = PyTuple_GET_ITEM(args, 1); + phi::DataType dtype = CastPyArg2DataTypeDirectly(dtype_obj, "dtype", 1); + PyObject *shape_obj = PyTuple_GET_ITEM(args, 2); + phi::IntArray shape = CastPyArg2IntArray(shape_obj, "shape", 2); + // Call ir static api + auto static_api_out = + paddle::dialect::get_parameter(name, dtype, shape.GetData()); + + return ToPyObject(static_api_out); + } catch (...) { + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + +static PyObject *static_api_set_parameter(PyObject *self, + PyObject *args, + PyObject *kwargs) { + try { + VLOG(6) << "Add set_parameter op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + + // Get OpResult from args + PyObject *parameter_obj = PyTuple_GET_ITEM(args, 0); + auto parameter = CastPyArg2OpResult(parameter_obj, "parameter", 0); + + // Parse Attributes + PyObject *name_obj = PyTuple_GET_ITEM(args, 1); + std::string name = CastPyArg2String(name_obj, "name", 1); + // Call ir static api + paddle::dialect::set_parameter(parameter, name); + + Py_RETURN_NONE; + } catch (...) 
{ + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + +static PyMethodDef ManualOpsAPI[] = { + {"set_parameter", + (PyCFunction)(void (*)(void))static_api_set_parameter, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for set_parameter."}, + {"get_parameter", + (PyCFunction)(void (*)(void))static_api_get_parameter, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for get_parameter."}, + {nullptr, nullptr, 0, nullptr}}; + +} // namespace pybind + +} // namespace paddle diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 266578615e352..a1e22b94ce192 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -34,6 +34,8 @@ #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/imperative.h" #include "paddle/phi/common/complex.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/value.h" namespace paddle { namespace pybind { @@ -829,6 +831,54 @@ void CastPyArg2AttrBlock(PyObject* obj, attrs[key] = reinterpret_cast(vh[0]); } +void CastPyArg2AttrIRBlock(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, + const std::string& op_type, + ssize_t arg_pos) { + VLOG(1) << "After Process pir::Block*"; + ::pybind11::detail::instance* inst = + (::pybind11::detail::instance*)obj; // NOLINT + void** vh = inst->simple_layout ? inst->simple_value_holder + : &inst->nonsimple.values_and_holders[0]; + attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]); +} + +void CastPyArg2AttrValues(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, + const std::string& op_type, + ssize_t arg_pos) { + std::vector<::pir::Value> results; + if (PyList_Check(obj)) { + Py_ssize_t len = PyList_Size(obj); + PyObject* item = nullptr; + for (Py_ssize_t i = 0; i < len; i++) { + // TODO(xiongkun): judge OpResult or Value; + item = PyList_GetItem(obj, i); + ::pybind11::detail::instance* inst = + (::pybind11::detail::instance*)item; // NOLINT + void** vh = inst->simple_layout ? 
inst->simple_value_holder + : &inst->nonsimple.values_and_holders[0]; + ::pir::OpResult* opresult = reinterpret_cast<::pir::OpResult*>(vh[0]); + if (opresult->impl() == nullptr) { + results.emplace_back(pir::Value(nullptr)); + } else { + results.emplace_back(pir::Value(opresult->Value::impl())); + } + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument (position %d) must be " + "a list of OpResult or Value, but got %s", + op_type, + arg_pos + 1, + ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT + } + attrs[key] = results; + VLOG(1) << "Pybind: Cast " << results.size() << " Values finished."; +} + void ConstructAttrMapFromPyArgs( const std::string& op_type, PyObject* args, @@ -847,6 +897,7 @@ PyObject* obj = nullptr; for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) { + VLOG(1) << "Start Process " << arg_pos; Py_ssize_t key_len; const char* key_ptr; obj = PyTuple_GET_ITEM(args, arg_pos); @@ -862,6 +913,7 @@ } std::string key(key_ptr, (size_t)key_len); // NOLINT + VLOG(1) << "Start Process " << key; auto iter = attr_type_map->find(key); if (iter == attr_type_map->end()) { continue; } @@ -921,6 +973,77 @@ } } +void ConstructAttrMapForRunProgram( + const std::string& op_type, + PyObject* args, + ssize_t attr_start, + ssize_t attr_end, + paddle::framework::AttributeMap& attrs) { // NOLINT + PADDLE_ENFORCE_EQ((attr_end - attr_start) % 2, + 0, + platform::errors::InvalidArgument( + "The number of arguments for attributes should be even " + "but attr_start = %d, attr_end = %d.", + attr_start, + attr_end)); + + PyObject* obj = nullptr; + for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) { + VLOG(1) << "Start Process " << arg_pos; + Py_ssize_t key_len; + const char* key_ptr; + obj = PyTuple_GET_ITEM(args, arg_pos); + if (PyObject_CheckString(obj)) { + key_ptr = PyUnicode_AsUTF8AndSize(obj, &key_len); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument (position %d) must be str, but got " + "%s", + op_type, + arg_pos, + ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT + } + + std::string key(key_ptr, (size_t)key_len); // NOLINT + VLOG(1) << "Start Process " << key; + obj = PyTuple_GET_ITEM(args, arg_pos + 1); + + if (std::set({"cuda_graph_capture_mode"}).count(key)) { + CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos); + } else if (std::set({"global_block", + "forward_global_block", + "backward_global_block"}) + .count(key)) { + CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos); + } else if (std::set({"is_test", "use_interpretorcore"}) + .count(key)) { + CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos); + } else if (std::set({"start_op_index", + "end_op_index", + "program_id", + "cuda_graph_pool_id"}) + .count(key)) { + CastPyArg2AttrLong(obj, attrs, key, op_type, arg_pos); + } else if (std::set({"fx", + "fp", + "fm", + "fo", + "bx", + "bp", + "bm", + "bo_g", + "bx_g", + "bp_g", + "bo"}) + .count(key)) { + CastPyArg2AttrValues(obj, attrs, key, op_type, arg_pos); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s is not defined in this function.", key)); // NOLINT + } + } +} + unsigned long GetUnsignedLongFromArgs( // NOLINT const std::string& op_type, const std::string& arg_name, diff --git a/paddle/fluid/pybind/op_function_common.h b/paddle/fluid/pybind/op_function_common.h index a3f4960bbd58b..2d02dd6fb784d 100644 ---
a/paddle/fluid/pybind/op_function_common.h +++ b/paddle/fluid/pybind/op_function_common.h @@ -194,6 +194,13 @@ void ConstructAttrMapFromPyArgs( ssize_t attr_end, paddle::framework::AttributeMap& attrs); // NOLINT +void ConstructAttrMapForRunProgram( + const std::string& op_type, + PyObject* args, + ssize_t attr_start, + ssize_t attr_end, + paddle::framework::AttributeMap& attrs); // NOLINT + unsigned long GetUnsignedLongFromArgs( // NOLINT const std::string& op_type, const std::string& arg_name, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 056c4b0daadfc..9d1cd87280179 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -195,17 +195,18 @@ limitations under the License. */ #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/nan_inf_utils.h" #include "paddle/fluid/imperative/layout_autotune.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/interface/vjp.h" +#include "paddle/fluid/pir/dialect/operator/interface/vjp.h" +#include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h" #include "paddle/fluid/prim/utils/static/static_tensor_operants.h" #include "paddle/fluid/pybind/eager_utils.h" -#include "paddle/ir/core/program.h" #include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/include/operants_manager.h" #include "paddle/phi/api/include/tensor_operants.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/kernels/autotune/cache.h" #include "paddle/phi/kernels/autotune/switch_autotune.h" +#include "paddle/pir/core/program.h" #include "pybind11/stl.h" PHI_DECLARE_bool(use_mkldnn); @@ -676,7 +677,7 @@ static void AssertStaticGraphAndDygraphGradMakerNoDiff() { string::join_strings(ops, ','))); } -#ifdef PADDLE_WITH_NCCL +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) static int GetNCCLVersion() { #if NCCL_VERSION_CODE >= 2304 int ver; @@ -692,19 +693,19 @@ static int GetNCCLVersion() { void BindVjp(pybind11::module *m) { m->def( "call_vjp", - [](ir::Operation &fwd_op, - const std::vector> &out_grads, + [](pir::Operation &fwd_op, + const std::vector> &out_grads, const std::vector> &stop_gradients) { py::list res; - ir::IrContext *ctx = ir::IrContext::Instance(); - ir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name()); + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name()); auto vjp_interface_impl = fwd_op_info.GetInterfaceImpl(); if (vjp_interface_impl == nullptr) { PADDLE_THROW(phi::errors::InvalidArgument( "The vjp function is not registered in %s op ", fwd_op.name())); } - std::vector> vjp_res = + std::vector> vjp_res = vjp_interface_impl->vjp_(&fwd_op, out_grads, stop_gradients); PADDLE_ENFORCE_EQ( stop_gradients.size(), @@ -743,14 +744,29 @@ void BindVjp(pybind11::module *m) { return res; }); - m->def("has_vjp", [](ir::Operation &fwd_op) { - ir::IrContext *ctx = ir::IrContext::Instance(); - ir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name()); + m->def("has_vjp", [](pir::Operation &fwd_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name()); auto vjp_interface_impl = fwd_op_info.GetInterfaceImpl(); if (vjp_interface_impl == nullptr) return false; return true; }); + + m->def( + "has_custom_vjp", + [](pir::Operation &op) -> py::bool_ { + return op.info().HasTrait(); + }, + R"DOC( + Return whether an op has custom vjp rules. 
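The `has_custom_vjp` binding above lost its template argument to angle-bracket stripping; a plausible reconstruction of it, assuming the trait declared in the newly included custom_vjp.h is named `CustomVjpTrait`, is:

```cpp
// Sketch only: CustomVjpTrait is inferred from the custom_vjp.h include;
// the stripped call is assumed to be op.info().HasTrait<...>().
void BindHasCustomVjp(pybind11::module *m) {
  m->def("has_custom_vjp", [](pir::Operation &op) -> bool {
    // True iff the op's OpInfo registered the custom-vjp trait.
    return op.info().HasTrait<paddle::dialect::CustomVjpTrait>();
  });
}
```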
+ + Args: + op (pir::Operation): op to be checked + + Returns: + out (bool): True means that the op has custom vjp rules, False means it does not. + )DOC"); } PYBIND11_MODULE(libpaddle, m) { BindImperative(&m); @@ -872,7 +888,7 @@ PYBIND11_MODULE(libpaddle, m) { }); #endif -#ifdef PADDLE_WITH_NCCL +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) m.def("nccl_version", &GetNCCLVersion); #endif @@ -1239,11 +1255,15 @@ All parameter, weight, gradient are variables in Paddle. Examples: .. code-block:: python - # create tensor from a scope and set value to it. - param = scope.var('Param').get_tensor() - param_array = np.full((height, row_numel), 5.0).astype("float32") - param.set(param_array, place) + >>> import paddle + >>> import numpy as np + >>> scope = paddle.static.global_scope() + >>> place = paddle.CPUPlace() + >>> # create tensor from a scope and set value to it. + >>> param = scope.var('Param').get_tensor() + >>> param_array = np.full((10, 12), 5.0).astype("float32") + >>> param.set(param_array, place) )DOC"); g_framework_scope_pytype = reinterpret_cast(_Scope.ptr()); _Scope @@ -1983,7 +2003,7 @@ All parameter, weight, gradient are variables in Paddle. py::init< const std::vector> &, const std::unordered_map> &>(), + std::shared_ptr<::pir::Program>> &>(), py::arg("job_list"), py::arg("type_to_ir_program")) .def("job_list", &framework::interpreter::Plan::JobList) @@ -2148,9 +2168,8 @@ All parameter, weight, gradient are variables in Paddle. Examples: .. code-block:: python - import paddle.base as base - - arr = base.LoDTensorArray() + >>> import paddle + >>> arr = paddle.framework.core.LoDTensorArray() )DOC"); g_framework_lodtensorarray_pytype = reinterpret_cast(pylodtensorarray.ptr()); @@ -2190,15 +2209,15 @@ All parameter, weight, gradient are variables in Paddle. None. Examples: - .. code-block:: python + .. code-block:: python - import paddle.base as base - import numpy as np + >>> import paddle + >>> import numpy as np - arr = base.LoDTensorArray() - t = base.LoDTensor() - t.set(np.ndarray([5, 30]), base.CPUPlace()) - arr.append(t) + >>> arr = paddle.framework.core.LoDTensorArray() + >>> t = paddle.framework.core.LoDTensor() + >>> t.set(np.ndarray([5, 30]), paddle.CPUPlace()) + >>> arr.append(t) )DOC") .def( "_move_to_list", diff --git a/paddle/fluid/pybind/tensor.cc b/paddle/fluid/pybind/tensor.cc index b3edc9575223d..95e217365be3d 100644 --- a/paddle/fluid/pybind/tensor.cc +++ b/paddle/fluid/pybind/tensor.cc @@ -393,11 +393,11 @@ void BindTensor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle.base as base - import numpy as np + >>> import paddle + >>> import numpy as np - t = base.Tensor() - t.set(np.ndarray([5, 30]), base.CPUPlace()) + >>> t = paddle.framework.core.Tensor() + >>> t.set(np.ndarray([5, 30]), paddle.CPUPlace()) )DOC") .def( @@ -411,14 +411,15 @@ void BindTensor(pybind11::module &m) { // NOLINT Examples: - .. code-block:: python + .. code-block:: python - import paddle.base as base - import numpy as np + >>> import paddle + >>> import numpy as np - t = base.Tensor() - t.set(np.ndarray([5, 30]), base.CPUPlace()) - print(t.shape()) # [5, 30] + >>> t = paddle.framework.core.Tensor() + >>> t.set(np.ndarray([5, 30]), paddle.CPUPlace()) + >>> print(t.shape()) + [5, 30] )DOC") .def("_to_dlpack", [](phi::DenseTensor &self) { @@ -515,15 +516,16 @@ void BindTensor(pybind11::module &m) { // NOLINT None. Examples: - .. code-block:: python + .. 
code-block:: python - import paddle.base as base - import numpy as np + >>> import paddle + >>> import numpy as np - t = base.Tensor() - t.set(np.ndarray([5, 30]), base.CPUPlace()) - t.set_lod([[0, 2, 5]]) - print(t.lod()) # [[0, 2, 5]] + >>> t = paddle.framework.core.Tensor() + >>> t.set(np.ndarray([5, 30]), paddle.CPUPlace()) + >>> t.set_lod([[0, 2, 5]]) + >>> print(t.lod()) + [[0, 2, 5]] )DOC") .def( "set_recursive_sequence_lengths", @@ -564,16 +566,18 @@ void BindTensor(pybind11::module &m) { // NOLINT None. Examples: - .. code-block:: python - - import paddle.base as base - import numpy as np - - t = base.Tensor() - t.set(np.ndarray([5, 30]), base.CPUPlace()) - t.set_recursive_sequence_lengths([[2, 3]]) - print(t.recursive_sequence_lengths()) # [[2, 3]] - print(t.lod()) # [[0, 2, 5]] + .. code-block:: python + + >>> import paddle + >>> import numpy as np + + >>> t = paddle.framework.core.Tensor() + >>> t.set(np.ndarray([5, 30]), paddle.CPUPlace()) + >>> t.set_recursive_sequence_lengths([[2, 3]]) + >>> print(t.recursive_sequence_lengths()) + [[2, 3]] + >>> print(t.lod()) + [[0, 2, 5]] )DOC") .def( "lod", @@ -592,15 +596,16 @@ void BindTensor(pybind11::module &m) { // NOLINT list[list[int]]: The lod of the Tensor. Examples: - .. code-block:: python + .. code-block:: python - import paddle.base as base - import numpy as np + >>> import paddle + >>> import numpy as np - t = base.Tensor() - t.set(np.ndarray([5, 30]), base.CPUPlace()) - t.set_lod([[0, 2, 5]]) - print(t.lod()) # [[0, 2, 5]] + >>> t = paddle.framework.core.Tensor() + >>> t.set(np.ndarray([5, 30]), paddle.CPUPlace()) + >>> t.set_lod([[0, 2, 5]]) + >>> print(t.lod()) + [[0, 2, 5]] )DOC") // Set above comments of set_lod. .def( @@ -621,15 +626,16 @@ void BindTensor(pybind11::module &m) { // NOLINT list[list[int]]: The recursive sequence lengths. Examples: - .. code-block:: python + .. code-block:: python - import paddle.base as base - import numpy as np + >>> import paddle + >>> import numpy as np - t = base.Tensor() - t.set(np.ndarray([5, 30]), base.CPUPlace()) - t.set_recursive_sequence_lengths([[2, 3]]) - print(t.recursive_sequence_lengths()) # [[2, 3]] + >>> t = paddle.framework.core.Tensor() + >>> t.set(np.ndarray([5, 30]), paddle.CPUPlace()) + >>> t.set_recursive_sequence_lengths([[2, 3]]) + >>> print(t.recursive_sequence_lengths()) + [[2, 3]] )DOC") .def( "has_valid_recursive_sequence_lengths", @@ -645,15 +651,16 @@ void BindTensor(pybind11::module &m) { // NOLINT bool: Whether the LoD is valid. Examples: - .. code-block:: python + .. code-block:: python - import paddle.base as base - import numpy as np + >>> import paddle + >>> import numpy as np - t = base.Tensor() - t.set(np.ndarray([5, 30]), base.CPUPlace()) - t.set_recursive_sequence_lengths([[2, 3]]) - print(t.has_valid_recursive_sequence_lengths()) # True + >>> t = paddle.framework.core.Tensor() + >>> t.set(np.ndarray([5, 30]), paddle.CPUPlace()) + >>> t.set_recursive_sequence_lengths([[2, 3]]) + >>> print(t.has_valid_recursive_sequence_lengths()) + True )DOC") .def("_as_type", [](const phi::DenseTensor &self, @@ -773,12 +780,12 @@ void BindTensor(pybind11::module &m) { // NOLINT tensor dims, lod information, device index. Examples: - .. code-block:: python + .. 
code-block:: python - import paddle - tensor = paddle.ones([3,3]) - metainfo = tensor.value().get_tensor()._share_cuda() + >>> import paddle + >>> tensor = paddle.ones([3,3]) + >>> metainfo = tensor.value().get_tensor()._share_cuda() )DOC") .def("_new_shared_cuda", [](py::tuple t) { @@ -819,13 +826,13 @@ void BindTensor(pybind11::module &m) { // NOLINT tensor dims, lod information, device index. Examples: - .. code-block:: python + .. code-block:: python - import paddle - tensor = paddle.ones([3,3]) - metainfo = tensor.value().get_tensor()._share_cuda() - tensor_from_shared = paddle.to_tensor(paddle.base.core.LoDTensor._new_shared_cuda(metainfo)) + >>> import paddle + >>> tensor = paddle.ones([3,3]) + >>> metainfo = tensor.value().get_tensor()._share_cuda() + >>> tensor_from_shared = paddle.to_tensor(paddle.base.core.LoDTensor._new_shared_cuda(metainfo)) )DOC") #endif .def("_share_filename", @@ -896,12 +903,12 @@ void BindTensor(pybind11::module &m) { // NOLINT tensor dims and lod imformation. Examples: - .. code-block:: python + .. code-block:: python - import paddle - tensor = paddle.ones([3,3]) - metainfo = tensor.value().get_tensor()._share_filename() + >>> import paddle + >>> tensor = paddle.ones([3,3]) + >>> metainfo = tensor.value().get_tensor()._share_filename() )DOC") .def("_new_shared_filename", [](py::tuple t) { // __setstate__ @@ -940,13 +947,13 @@ void BindTensor(pybind11::module &m) { // NOLINT tensor dims and lod information. Examples: - .. code-block:: python + .. code-block:: python - import paddle - tensor = paddle.ones([3,3]) - metainfo = tensor.value().get_tensor()._share_filename() - tensor_from_shared = paddle.to_tensor(paddle.base.core.LoDTensor._new_shared_filename(metainfo)) + >>> import paddle + >>> tensor = paddle.ones([3,3]) + >>> metainfo = tensor.value().get_tensor()._share_filename() + >>> tensor_from_shared = paddle.to_tensor(paddle.base.core.LoDTensor._new_shared_filename(metainfo)) )DOC") .def("_shared_incref", [](phi::DenseTensor &self) { diff --git a/paddle/ir/core/CMakeLists.txt b/paddle/ir/core/CMakeLists.txt deleted file mode 100644 index 138b102fcbd89..0000000000000 --- a/paddle/ir/core/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -set(NEWIR_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/ir") -set(NEWIR_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/ir") - -file(GLOB IR_SRCS "*.cc") - -file(GLOB IR_PARSER_SRCS "parser/*.cc") - -list(APPEND IR_SRCS ${IR_PARSER_SRCS}) - -ir_library(ir_core SRCS ${IR_SRCS} DEPS ddim) diff --git a/paddle/ir/core/op_base.h b/paddle/ir/core/op_base.h deleted file mode 100644 index 0a491795d4eed..0000000000000 --- a/paddle/ir/core/op_base.h +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
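The deleted header below (relocated from paddle/ir to paddle/pir by this PR) defines the CRTP bases for ops, traits, and interfaces. A minimal sketch of how a trait built on `OpTraitBase` is queried, using a hypothetical `ReadOnlyTrait` and the post-rename `pir::` spelling:

```cpp
// Hypothetical trait for illustration; real traits are listed in an op's
// trait tuple and registered via GetTraitSet().
struct ReadOnlyTrait : public pir::OpTraitBase<ReadOnlyTrait> {
  explicit ReadOnlyTrait(pir::Operation *op) : OpTraitBase(op) {}
};

void Inspect(pir::Operation *op) {
  // dyn_cast yields an engaged handle only if op->HasTrait<ReadOnlyTrait>().
  if (auto t = ReadOnlyTrait::dyn_cast(op)) {
    VLOG(6) << t->name() << " is read-only";
  }
}
```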
- -#pragma once -#include - -#include "paddle/ir/core/enforce.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/utils.h" - -namespace ir { - -class IR_API InterfaceValue { - public: - template - static InterfaceValue get() { - InterfaceValue val; - val.type_id_ = TypeId::get(); - val.model_ = malloc(sizeof(typename T::template Model)); - if (val.model_ == nullptr) { - throw("Alloc memory for interface failed."); - } - static_assert(std::is_trivially_destructible< - typename T::template Model>::value, - "interface models must be trivially destructible"); - new (val.model_) typename T::template Model(); - return val; - } - TypeId type_id() const { return type_id_; } - void *model() const { return model_; } - - InterfaceValue() = default; - explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {} - InterfaceValue(const InterfaceValue &) = delete; - InterfaceValue(InterfaceValue &&) noexcept; - InterfaceValue &operator=(const InterfaceValue &) = delete; - InterfaceValue &operator=(InterfaceValue &&) noexcept; - ~InterfaceValue(); - void swap(InterfaceValue &&val) { - using std::swap; - swap(type_id_, val.type_id_); - swap(model_, val.model_); - } - - /// - /// \brief Comparison operations. - /// - inline bool operator<(const InterfaceValue &other) const { - return type_id_ < other.type_id_; - } - - private: - TypeId type_id_; - void *model_{nullptr}; -}; - -class IR_API OpBase { - public: - explicit OpBase(Operation *operation = nullptr) : operation_(operation) {} - - Operation *operation() const { - IR_ENFORCE(operation_, "Can't use operation() in a null op."); - return operation_; - } - - explicit operator bool() const { return operation_ != nullptr; } - - operator Operation *() const { return operation(); } - - Operation *operator->() const { return operation(); } - - IrContext *ir_context() const { return operation()->ir_context(); } - - uint32_t num_results() const { return operation()->num_results(); } - - uint32_t num_operands() const { return operation()->num_operands(); } - - const AttributeMap &attributes() const { return operation()->attributes(); } - - Value operand_source(uint32_t index) const { - return operation()->operand_source(index); - } - - OpResult result(uint32_t index) const { return operation()->result(index); } - - ir::Attribute attribute(const std::string &name) { - return operation()->attribute(name); - } - - template - T attribute(const std::string &name) { - return operation()->attribute(name); - } - - private: - Operation *operation_; // Not owned -}; - -/// -/// \brief OpTrait -/// -template -class OpTraitBase : public OpBase { - public: - explicit OpTraitBase(Operation *op) : OpBase(op) {} - - static TypeId GetTraitId() { return TypeId::get(); } - - static ConcreteTrait dyn_cast(Operation *op) { - if (op && op->HasTrait()) { - return ConcreteTrait(op); - } - return ConcreteTrait(nullptr); - } -}; - -/// -/// \brief OpInterface -/// -template -class OpInterfaceBase : public OpBase { - public: - explicit OpInterfaceBase(Operation *op) : OpBase(op) {} - - static TypeId GetInterfaceId() { return TypeId::get(); } - - static ConcreteInterface dyn_cast(Operation *op) { - if (op && op->HasInterface()) { - return ConcreteInterface( - op, op->info().GetInterfaceImpl()); - } - return ConcreteInterface(nullptr, nullptr); - } -}; - -template -class ConstructInterfacesOrTraits { - public: - /// Construct method for interfaces. 
- static InterfaceValue *interface(InterfaceValue *p_interface) { - (void)std::initializer_list{ - 0, (PlacementConstrctInterface(p_interface), 0)...}; - return p_interface; - } - - /// Construct method for traits. - static TypeId *trait(TypeId *p_trait) { - (void)std::initializer_list{ - 0, (PlacementConstrctTrait(p_trait), 0)...}; - return p_trait; - } - - private: - /// Placement new interface. - template - static void PlacementConstrctInterface( - InterfaceValue *&p_interface) { // NOLINT - p_interface->swap(InterfaceValue::get()); - VLOG(6) << "New a interface: id[" - << (p_interface->type_id()).AsOpaquePointer() << "]."; - ++p_interface; - } - - /// Placement new trait. - template - static void PlacementConstrctTrait(ir::TypeId *&p_trait) { // NOLINT - *p_trait = TypeId::get(); - VLOG(6) << "New a trait: id[" << p_trait->AsOpaquePointer() << "]."; - ++p_trait; - } -}; - -/// Specialized for tuple type. -template -class ConstructInterfacesOrTraits> { - public: - /// Construct method for interfaces. - static InterfaceValue *interface(InterfaceValue *p_interface) { - return ConstructInterfacesOrTraits::interface( - p_interface); - } - - /// Construct method for traits. - static TypeId *trait(TypeId *p_trait) { - return ConstructInterfacesOrTraits::trait(p_trait); - } -}; - -template -class Op : public OpBase { - public: - using OpBase::OpBase; - - using TraitList = - typename Filter>::Type; - - using InterfaceList = - typename Filter>::Type; - - static ConcreteOp dyn_cast(Operation *op) { - if (op && op->info().id() == TypeId::get()) { - return ConcreteOp(op); - } - return ConcreteOp(nullptr); - } - - static bool classof(const Operation *op) { - return op && op->info().id() == TypeId::get(); - } - - static std::vector GetInterfaceMap() { - constexpr size_t interfaces_num = std::tuple_size::value; - std::vector interfaces_map(interfaces_num); - ConstructInterfacesOrTraits::interface( - interfaces_map.data()); - return interfaces_map; - } - - static std::vector GetTraitSet() { - constexpr size_t traits_num = std::tuple_size::value; - std::vector trait_set(traits_num); - auto p_first_trait = trait_set.data(); - ConstructInterfacesOrTraits::trait(p_first_trait); - return trait_set; - } - static constexpr bool HasNoDataMembers() { - class EmptyOp : public Op {}; - return sizeof(ConcreteOp) == sizeof(EmptyOp); - } - - static void VerifyInvariants(Operation *op) { - static_assert(HasNoDataMembers(), - "Op class shouldn't define new data members"); - op->dyn_cast().Verify(); - } -}; - -} // namespace ir diff --git a/paddle/ir/core/type.h b/paddle/ir/core/type.h deleted file mode 100644 index f27503b3731f4..0000000000000 --- a/paddle/ir/core/type.h +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
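type.h, deleted below, keeps `Type` as a value-semantic handle over a uniqued `TypeStorage*`, so comparison, hashing, and isa/dyn_cast all reduce to pointer operations. A short sketch of what that enables (`DenseTensorType` stands in for any registered subtype and is hypothetical here):

```cpp
#include <unordered_map>

void TypeHandleDemo(ir::Type a, ir::Type b) {
  if (a == b) {
    // Same uniqued storage, hence the same type instance.
  }
  std::unordered_map<ir::Type, int> seen;  // enabled by std::hash<ir::Type>
  ++seen[a];
  // dyn_cast is engaged only when the storage's TypeId matches the subtype.
  if (auto dense = a.dyn_cast<ir::DenseTensorType>()) {  // hypothetical subtype
    (void)dense;
  }
}
```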
- -#pragma once - -#include - -#include "paddle/ir/core/cast_utils.h" -#include "paddle/ir/core/type_id.h" - -namespace ir { -class TypeStorage; -class AbstractType; -class IrContext; -class Dialect; -/// -/// \brief Unified interface of the Type class. Derivation of all Type classes -/// only derives interfaces, not members. For example, DenseTensorType, -/// Float32Type, etc. are all derived classes of Type, but no new member -/// variables will be added. -/// -class IR_API Type { - public: - using Storage = TypeStorage; - - Type() = default; - - Type(const Storage *storage) // NOLINT - : storage_(const_cast(storage)) {} - - Type(const Type &other) = default; - - Type &operator=(const Type &other) = default; - - /// - /// \brief Some operators are overloaded. - /// - bool operator==(Type other) const { return storage_ == other.storage_; } - - bool operator!=(Type other) const { return storage_ != other.storage_; } - - explicit operator bool() const { return storage_; } - - bool operator!() const { return storage_ == nullptr; } - - /// - /// \brief Some type attribute acquisition interfaces. - /// - TypeId type_id(); - - const AbstractType &abstract_type(); - - const Storage *storage() const { return storage_; } - - Dialect &dialect() const; - - IrContext *ir_context() const; - - /// - /// \brief Methods for type judgment and cast. - /// - static bool classof(Type) { return true; } - - template - bool isa() const { - return ir::isa(*this); - } - - template - U dyn_cast() const { - return ir::dyn_cast(*this); - } - - void Print(std::ostream &os) const; - - static Type Parse(std::istream &is, IrContext *ctx); - - /// - /// \brief Enable hashing Type. - /// - friend struct std::hash; - - protected: - const Storage *storage_{nullptr}; -}; - -IR_API std::ostream &operator<<(std::ostream &os, Type type); - -} // namespace ir - -/// -/// \brief This class represents the base of a type interface. -/// - -// template -// class TypeInterface : public ir::DialectInterface { -// public: -// using Base = TypeInterface; -// using DialectInterfaceBase = ir::DialectInterface; -// using DialectInterfaceBase::Base; - -// private: -// /// Returns the impl interface instance for the given type. -// static typename InterfaceBase::Concept *getInterfaceFor(Type type) { -// return type.getAbstractType().getInterface(); -// } - -// /// Allow access to 'getInterfaceFor'. -// friend InterfaceBase; -// }; - -namespace std { -/// -/// \brief Enable hashing Type. -/// -template <> -struct hash { - std::size_t operator()(const ir::Type &obj) const { - return std::hash()(obj.storage_); - } -}; -} // namespace std diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc deleted file mode 100644 index c652ef23a6dde..0000000000000 --- a/paddle/ir/core/value.cc +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
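value.cc, deleted below, maintains each Value's uses as an intrusive singly linked list of OpOperandImpls, so use iteration and bulk retargeting need no side tables. A sketch against the public API it implements:

```cpp
#include <cassert>
#include <cstddef>

void RetargetUses(ir::Value old_val, ir::Value new_val) {
  // Walks the use-def chain and re-links every operand to new_val.
  old_val.ReplaceAllUsesWith(new_val);
  assert(old_val.use_empty());
}

size_t CountUses(ir::Value v) {
  size_t n = 0;  // mirrors Value::use_count()
  for (auto it = v.use_begin(); it != v.use_end(); ++it) ++n;
  return n;
}
```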
- -#include "paddle/ir/core/value.h" - -#include - -#include "paddle/ir/core/enforce.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/value_impl.h" - -#define CHECK_NULL_IMPL(class_name, func_name) \ - IR_ENFORCE(impl_, \ - "impl_ pointer is null when call func:" #func_name \ - " , in class: " #class_name ".") - -#define CHECK_OPOPEREND_NULL_IMPL(func_name) \ - CHECK_NULL_IMPL(OpOpernad, func_name) - -#define CHECK_VALUE_NULL_IMPL(func_name) CHECK_NULL_IMPL(Value, func_name) - -#define CHECK_OPRESULT_NULL_IMPL(func_name) CHECK_NULL_IMPL(OpResult, func_name) -namespace ir { - -// Operand -OpOperand::OpOperand(const detail::OpOperandImpl *impl) - : impl_(const_cast(impl)) {} - -OpOperand &OpOperand::operator=(const OpOperand &rhs) { - if (this == &rhs) return *this; - impl_ = rhs.impl_; - return *this; -} - -OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) { - if (this->impl_ == impl) return *this; - impl_ = const_cast(impl); - return *this; -} -OpOperand::operator bool() const { return impl_ && impl_->source(); } - -OpOperand OpOperand::next_use() const { - CHECK_OPOPEREND_NULL_IMPL(next_use); - return impl_->next_use(); -} - -Value OpOperand::source() const { - CHECK_OPOPEREND_NULL_IMPL(source); - return impl_->source(); -} - -Type OpOperand::type() const { return source().type(); } - -void OpOperand::set_source(Value value) { - CHECK_OPOPEREND_NULL_IMPL(set_source); - impl_->set_source(value); -} - -Operation *OpOperand::owner() const { - CHECK_OPOPEREND_NULL_IMPL(owner); - return impl_->owner(); -} - -void OpOperand::RemoveFromUdChain() { - CHECK_OPOPEREND_NULL_IMPL(RemoveFromUdChain); - return impl_->RemoveFromUdChain(); -} - -// Value -Value::Value(const detail::ValueImpl *impl) - : impl_(const_cast(impl)) {} - -bool Value::operator==(const Value &other) const { - return impl_ == other.impl_; -} - -bool Value::operator!=(const Value &other) const { - return impl_ != other.impl_; -} - -bool Value::operator!() const { return impl_ == nullptr; } - -Value::operator bool() const { return impl_; } - -ir::Type Value::type() const { - CHECK_VALUE_NULL_IMPL(type); - return impl_->type(); -} - -void Value::set_type(ir::Type type) { - CHECK_VALUE_NULL_IMPL(set_type); - impl_->set_type(type); -} - -Operation *Value::GetDefiningOp() const { - if (auto result = dyn_cast()) return result.owner(); - return nullptr; -} - -std::string Value::PrintUdChain() { - CHECK_VALUE_NULL_IMPL(PrintUdChain); - return impl()->PrintUdChain(); -} - -Value::UseIterator Value::use_begin() const { - return ir::OpOperand(first_use()); -} - -Value::UseIterator Value::use_end() const { return Value::UseIterator(); } - -OpOperand Value::first_use() const { - CHECK_VALUE_NULL_IMPL(first_use); - return impl_->first_use(); -} - -bool Value::use_empty() const { return !first_use(); } - -bool Value::HasOneUse() const { - CHECK_VALUE_NULL_IMPL(HasOneUse); - return impl_->HasOneUse(); -} - -size_t Value::use_count() const { - size_t count = 0; - for (auto it = use_begin(); it != use_end(); ++it) count++; - return count; -} - -void Value::ReplaceUsesWithIf( - Value new_value, - const std::function &should_replace) const { - for (auto it = use_begin(); it != use_end();) { - if (should_replace(*it)) { - (it++)->set_source(new_value); - } - } -} - -void Value::ReplaceAllUsesWith(Value new_value) const { - for (auto it = use_begin(); it != use_end();) { - (it++)->set_source(new_value); - } -} - -// OpResult -bool OpResult::classof(Value value) { - return value && ir::isa(value.impl()); -} - 
-Operation *OpResult::owner() const { - CHECK_OPRESULT_NULL_IMPL(owner); - return impl()->owner(); -} - -uint32_t OpResult::GetResultIndex() const { - CHECK_OPRESULT_NULL_IMPL(GetResultIndex); - return impl()->GetResultIndex(); -} - -detail::OpResultImpl *OpResult::impl() const { - return reinterpret_cast(impl_); -} - -bool OpResult::operator==(const OpResult &other) const { - return impl_ == other.impl_; -} - -detail::ValueImpl *OpResult::value_impl() const { - IR_ENFORCE(impl_, "Can't use value_impl() interface while value is null."); - return impl_; -} - -uint32_t OpResult::GetValidInlineIndex(uint32_t index) { - uint32_t max_inline_index = - ir::detail::OpResultImpl::GetMaxInlineResultIndex(); - return index <= max_inline_index ? index : max_inline_index; -} - -// details -namespace detail { -ir::Operation *OpOperandImpl::owner() const { return owner_; } - -ir::detail::OpOperandImpl *OpOperandImpl::next_use() { return next_use_; } - -ir::Value OpOperandImpl::source() const { return source_; } - -void OpOperandImpl::set_source(Value source) { - RemoveFromUdChain(); - if (!source) { - return; - } - source_ = source; - InsertToUdChain(); -} - -OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner) - : source_(source), owner_(owner) { - if (!source) { - return; - } - InsertToUdChain(); -} - -void OpOperandImpl::InsertToUdChain() { - prev_use_addr_ = source_.impl()->first_use_addr(); - next_use_ = source_.impl()->first_use(); - if (next_use_) { - next_use_->prev_use_addr_ = &next_use_; - } - source_.impl()->set_first_use(this); -} - -void OpOperandImpl::RemoveFromUdChain() { - if (!source_) return; - if (!prev_use_addr_) return; - if (prev_use_addr_ == source_.impl()->first_use_addr()) { - /// NOTE: In ValueImpl, first_use_offseted_by_index_ use lower three bits - /// storage index information, so need to be updated using the set_first_use - /// method here. - source_.impl()->set_first_use(next_use_); - } else { - *prev_use_addr_ = next_use_; - } - if (next_use_) { - next_use_->prev_use_addr_ = prev_use_addr_; - } - next_use_ = nullptr; - prev_use_addr_ = nullptr; - source_ = nullptr; -} - -OpOperandImpl::~OpOperandImpl() { RemoveFromUdChain(); } - -uint32_t ValueImpl::index() const { - uint32_t index = - reinterpret_cast(first_use_offseted_by_index_) & 0x07; - if (index < 6) return index; - return reinterpret_cast(const_cast(this)) - ->GetResultIndex(); -} - -std::string ValueImpl::PrintUdChain() { - std::stringstream result; - result << "Value[" << this << "] -> "; - OpOperandImpl *tmp = first_use(); - if (tmp) { - result << "OpOperand[" << reinterpret_cast(tmp) << "] -> "; - while (tmp->next_use() != nullptr) { - result << "OpOperand[" << reinterpret_cast(tmp->next_use()) - << "] -> "; - tmp = tmp->next_use(); - } - } - result << "nullptr"; - return result.str(); -} - -uint32_t OpResultImpl::GetResultIndex() const { - if (const auto *outline_result = ir::dyn_cast(this)) { - return outline_result->GetResultIndex(); - } - return ir::dyn_cast(this)->GetResultIndex(); -} - -OpResultImpl::~OpResultImpl() { assert(use_empty()); } - -ir::Operation *OpResultImpl::owner() const { - // For inline result, pointer offset index to obtain the address of op. - if (const auto *result = ir::dyn_cast(this)) { - result += result->GetResultIndex() + 1; - return reinterpret_cast( - const_cast(result)); - } - // For outline result, pointer offset outline_index to obtain the address of - // maximum inline result. 
- const OpOutlineResultImpl *outline_result = - (const OpOutlineResultImpl *)(this); - outline_result += - (outline_result->outline_index_ - GetMaxInlineResultIndex()); - // The offset of the maximum inline result distance op is - // GetMaxInlineResultIndex. - const auto *inline_result = - reinterpret_cast(outline_result); - inline_result += (GetMaxInlineResultIndex() + 1); - return reinterpret_cast( - const_cast(inline_result)); -} -} // namespace detail -} // namespace ir diff --git a/paddle/ir/core/value_impl.h b/paddle/ir/core/value_impl.h deleted file mode 100644 index 14a7b4d63f5d3..0000000000000 --- a/paddle/ir/core/value_impl.h +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/ir/core/value.h" - -namespace ir { -static const uint32_t OUTLINE_OP_RESULT_INDEX = 6; - -class Operation; - -namespace detail { -/// -/// \brief OpOperandImpl -/// -class OpOperandImpl { - public: - ir::Operation *owner() const; - - ir::detail::OpOperandImpl *next_use(); - - ir::Value source() const; - - void set_source(Value value); - - /// Remove this op_operand from the current use list. - void RemoveFromUdChain(); - - ~OpOperandImpl(); - - friend ir::Operation; - - private: - OpOperandImpl(ir::Value source, ir::Operation *owner); - - // Insert self to the UD chain holded by source_; - // It is not safe. So set private. - void InsertToUdChain(); - - ir::detail::OpOperandImpl *next_use_ = nullptr; - - ir::detail::OpOperandImpl **prev_use_addr_ = nullptr; - - ir::Value source_; - - ir::Operation *const owner_ = nullptr; -}; - -/// -/// \brief ValueImpl is the base class of all derived Value classes such as -/// OpResultImpl. This class defines all the information and usage interface in -/// the IR Value. Each Value include three attributes: -/// (1) type: ir::Type; (2) UD-chain of value: OpOperandImpl*, first op_operand -/// address with offset of this value; (3) index: the position where the output -/// list of the parent operator. -/// -class alignas(8) ValueImpl { - public: - /// - /// \brief Interface functions of "type_" attribute. - /// - ir::Type type() const { return type_; } - - void set_type(ir::Type type) { type_ = type; } - - /// - /// \brief Interface functions of "first_use_offseted_by_index_" attribute. - /// - uint32_t index() const; - - OpOperandImpl *first_use() const { - return reinterpret_cast( - reinterpret_cast(first_use_offseted_by_index_) & (~0x07)); - } - - void set_first_use(OpOperandImpl *first_use) { - uint32_t offset = index(); - first_use_offseted_by_index_ = reinterpret_cast( - reinterpret_cast(first_use) + offset); - VLOG(4) << "The index of this value is " << offset - << ". 
Offset and set first use: " << first_use << " -> " - << first_use_offseted_by_index_ << "."; - } - - OpOperandImpl **first_use_addr() { return &first_use_offseted_by_index_; } - - bool use_empty() const { return first_use() == nullptr; } - - bool HasOneUse() const { - return (first_use() != nullptr) && (first_use()->next_use() == nullptr); - } - - std::string PrintUdChain(); - - protected: - /// - /// \brief Only can be constructed by derived classes such as OpResultImpl. - /// - explicit ValueImpl(ir::Type type, uint32_t index) { - if (index > OUTLINE_OP_RESULT_INDEX) { - throw("The value of index must not exceed 6"); - } - type_ = type; - first_use_offseted_by_index_ = reinterpret_cast( - reinterpret_cast(nullptr) + index); - VLOG(4) << "Construct a ValueImpl whose's index is " << index - << ". The offset first_use address is: " - << first_use_offseted_by_index_; - } - - /// - /// \brief Attribute1: Type of value. - /// - ir::Type type_; - - /// - /// \brief Attribute2/3: Record the UD-chain of value and index. - /// NOTE: The members of the OpOperandImpl include four pointers, so this - /// class is 8-byte aligned, and the lower 3 bits of its address are 0, so the - /// index can be stored in these 3 bits, stipulate: - /// (1) index = 0~5: represent positions 0 to 5 inline - /// output(OpInlineResultImpl); (2) index = 6: represent the position >=6 - /// outline output(OpOutlineResultImpl); (3) index = 7 is reserved. - /// - OpOperandImpl *first_use_offseted_by_index_ = nullptr; -}; - -/// -/// \brief OpResultImpl is the implementation of an operation result. -/// -class alignas(8) OpResultImpl : public ValueImpl { - public: - using ValueImpl::ValueImpl; - - static bool classof(const ValueImpl &value) { return true; } - - /// - /// \brief Get the parent operation of this result.(op_ptr = value_ptr + - /// index) - /// - ir::Operation *owner() const; - - /// - /// \brief Get the result index of the operation result. - /// - uint32_t GetResultIndex() const; - - /// - /// \brief Get the maximum number of results that can be stored inline. - /// - static uint32_t GetMaxInlineResultIndex() { - return OUTLINE_OP_RESULT_INDEX - 1; - } - - ~OpResultImpl(); -}; - -/// -/// \brief OpInlineResultImpl is the implementation of an operation result whose -/// index <= 5. -/// -class OpInlineResultImpl : public OpResultImpl { - public: - OpInlineResultImpl(ir::Type type, uint32_t result_index) - : OpResultImpl(type, result_index) { - if (result_index > GetMaxInlineResultIndex()) { - throw("Inline result index should not exceed MaxInlineResultIndex(5)"); - } - } - - static bool classof(const OpResultImpl &value) { - return value.index() < OUTLINE_OP_RESULT_INDEX; - } - - uint32_t GetResultIndex() const { return index(); } -}; - -/// -/// \brief OpOutlineResultImpl is the implementation of an operation result -/// whose index > 5. 
-/// -class OpOutlineResultImpl : public OpResultImpl { - public: - OpOutlineResultImpl(ir::Type type, uint32_t outline_index) - : OpResultImpl(type, OUTLINE_OP_RESULT_INDEX), - outline_index_(outline_index) {} - - static bool classof(const OpResultImpl &value) { - return value.index() >= OUTLINE_OP_RESULT_INDEX; - } - - uint32_t GetResultIndex() const { return outline_index_; } - - uint32_t outline_index_; -}; - -} // namespace detail -} // namespace ir diff --git a/paddle/ir/dialect/control_flow/CMakeLists.txt b/paddle/ir/dialect/control_flow/CMakeLists.txt deleted file mode 100644 index 5a693ba156ccd..0000000000000 --- a/paddle/ir/dialect/control_flow/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -file(GLOB_RECURSE CONTROL_FLOW_SRCS "*.cc") -ir_library(ir_control_flow SRCS ${CONTROL_FLOW_SRCS} DEPS ir_core) diff --git a/paddle/ir/dialect/shape/CMakeLists.txt b/paddle/ir/dialect/shape/CMakeLists.txt deleted file mode 100644 index 62d7c0d42c85c..0000000000000 --- a/paddle/ir/dialect/shape/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -file(GLOB_RECURSE SHAPE_SRCS "*.cc") -ir_library(ir_shape SRCS ${SHAPE_SRCS} DEPS ir_core) diff --git a/paddle/ir/dialect/shape/ir/shape_op.cc b/paddle/ir/dialect/shape/ir/shape_op.cc deleted file mode 100644 index 776503ea269e3..0000000000000 --- a/paddle/ir/dialect/shape/ir/shape_op.cc +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
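shape_op.cc, deleted further below, gives `SymbolicDim::merge` three-way semantics: a hard value conflict (or a non-negative dim merged with a known -1) fails, otherwise the known-ness flags are OR-combined and a dynamic dim adopts the other side's static value. A usage sketch:

```cpp
void UnifyDims(ir::dialect::SymbolicDim lhs, ir::dialect::SymbolicDim rhs) {
  if (!lhs.merge(rhs)) {
    // e.g. lhs.getValue() == 4 while rhs.getValue() == 8:
    // the two dims can never be equal.
    VLOG(6) << "symbolic dims " << lhs.getSymName() << " and "
            << rhs.getSymName() << " are contradictory";
  }
  // On success both sides agree on the value and the known-ness flags.
}
```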
- -#include "paddle/ir/dialect/shape/ir/shape_op.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/builtin_type.h" - -namespace ir { -namespace dialect { - -const char *SymbolicDim::attributes_name[attributes_num] = {"knownNegativeOne", - "knownNonNegative", - "knownNonSizeOne", - "knownNonSizeZero", - "sym_name", - "value"}; // NOLINT - -void SymbolicDim::Build( - Builder &builder, - OperationArgument &argument, - const std::string &sym_name, - int64_t value, // TODO(zhangbo) value = ShapedType::kDynamic - bool knownNonNegative, - bool knownNegativeOne, - bool knownNonSizeOne, - bool knownNonSizeZero) { - ir::Attribute attr_sym_name = - ir::StrAttribute::get(ir::IrContext::Instance(), sym_name); - argument.AddAttribute("sym_name", attr_sym_name); - ir::Attribute attr_value = - ir::Int64Attribute::get(ir::IrContext::Instance(), value); - argument.AddAttribute("value", attr_value); - ir::Attribute attr_knownNonNegative = - ir::BoolAttribute::get(ir::IrContext::Instance(), knownNonNegative); - argument.AddAttribute("knownNonNegative", attr_knownNonNegative); - ir::Attribute attr_knownNegativeOne = - ir::BoolAttribute::get(ir::IrContext::Instance(), knownNegativeOne); - argument.AddAttribute("knownNegativeOne", attr_knownNegativeOne); - ir::Attribute attr_knownNonSizeOne = - ir::BoolAttribute::get(ir::IrContext::Instance(), knownNonSizeOne); - argument.AddAttribute("knownNonSizeOne", attr_knownNonSizeOne); - ir::Attribute attr_knownNonSizeZero = - ir::BoolAttribute::get(ir::IrContext::Instance(), knownNonSizeZero); - argument.AddAttribute("knownNonSizeZero", attr_knownNonSizeZero); -} - -const std::string SymbolicDim::getSymName() { - return attribute("sym_name").AsString(); -} -int64_t SymbolicDim::getValue() { - return attribute("value").data(); -} -bool SymbolicDim::getKnownNonNegative() { - return attribute("knownNonNegative").data(); -} -bool SymbolicDim::getKnownNegativeOne() { - return attribute("knownNegativeOne").data(); -} -bool SymbolicDim::getKnownNonSizeOne() { - return attribute("knownNonSizeOne").data(); -} -bool SymbolicDim::getKnownNonSizeZero() { - return attribute("knownNonSizeZero").data(); -} - -void SymbolicDim::updateSymName(std::string attrValue) { - operation()->set_attribute( - "sym_name", ir::StrAttribute::get(ir::IrContext::Instance(), attrValue)); -} -void SymbolicDim::updateValue(int64_t attrValue) { - operation()->set_attribute( - "value", ir::Int64Attribute::get(ir::IrContext::Instance(), attrValue)); -} - -void SymbolicDim::updateKnownNonNegative(bool attrValue) { - operation()->set_attribute( - "knownNonNegative", - ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue)); -} -void SymbolicDim::updateKnownNegativeOne(bool attrValue) { - operation()->set_attribute( - "knownNegativeOne", - ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue)); -} -void SymbolicDim::updateKnownNonSizeOne(bool attrValue) { - operation()->set_attribute( - "knownNonSizeOne", - ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue)); -} -void SymbolicDim::updateKnownNonSizeZero(bool attrValue) { - operation()->set_attribute( - "knownNonSizeZero", - ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue)); -} - -bool SymbolicDim::isDynamic() { - return getValue() == -100000; -} // TODO(zhangbo): getValue() == ShapedType::kDynamic; - -bool SymbolicDim::merge(SymbolicDim other) { - if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue()) - return false; - if (isDynamic() && !other.isDynamic()) 
updateValue(other.getValue()); - if (!isDynamic() && other.isDynamic()) other.updateValue(getValue()); - - bool knownNonNegativeFlag = - getKnownNonNegative() || other.getKnownNonNegative(); - bool knownNegativeOneFlag = - getKnownNegativeOne() || other.getKnownNegativeOne(); - bool knownNonSizeOneFlag = getKnownNonSizeOne() || - other.getKnownNonSizeOne() || knownNegativeOneFlag; - bool knownNonSizeZeroFlag = getKnownNonSizeZero() || - other.getKnownNonSizeZero() || - knownNegativeOneFlag; - - if (knownNonNegativeFlag && knownNegativeOneFlag) return false; - - updateKnownNonSizeZero(knownNonSizeZeroFlag); - updateKnownNonSizeOne(knownNonSizeOneFlag); - updateKnownNegativeOne(knownNegativeOneFlag); - updateKnownNonNegative(knownNonNegativeFlag); - - return true; -} - -const char *DimOp::attributes_name[attributes_num] = {"name"}; // NOLINT - -void DimOp::Build(Builder &builder, - OperationArgument &argument, - const std::string &name) { - ir::Attribute attr_name = - ir::StrAttribute::get(ir::IrContext::Instance(), name); - argument.AddAttribute("name", attr_name); - argument.output_types.emplace_back( - ir::IndexType::get(ir::IrContext::Instance())); -} - -const std::string DimOp::getName() { - return attribute("name").AsString(); -} - -void DimOp::setName(std::string attrName) { - operation()->set_attribute( - "name", ir::StrAttribute::get(ir::IrContext::Instance(), attrName)); -} - -const char *TieProductEqualOp::attributes_name[attributes_num] = { - "lhs_len", "rhs_len"}; // NOLINT - -void TieProductEqualOp::Build(Builder &builder, - OperationArgument &argument, - int64_t lhs_len, - int64_t rhs_len, - const std::vector &inputs) { - ir::Attribute attr_lhs_len = - ir::Int64Attribute::get(ir::IrContext::Instance(), lhs_len); - argument.AddAttribute("lhs_len", attr_lhs_len); - ir::Attribute attr_rhs_len = - ir::Int64Attribute::get(ir::IrContext::Instance(), rhs_len); - argument.AddAttribute("rhs_len", attr_rhs_len); - argument.inputs = inputs; -} - -std::vector TieProductEqualOp::getLhs() { - int64_t lhs_len = attribute("lhs_len").data(); - std::vector res; - for (uint32_t idx = 0; idx < lhs_len; idx++) { - res.push_back(operand_source(idx)); - } - return res; -} -std::vector TieProductEqualOp::getRhs() { - int64_t lhs_len = attribute("lhs_len").data(); - int64_t rhs_len = attribute("rhs_len").data(); - std::vector res; - for (uint32_t idx = 0; idx < rhs_len; idx++) { - res.push_back(operand_source(lhs_len + idx)); - } - return res; -} - -} // namespace dialect -} // namespace ir - -IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::SymbolicDim) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::DimOp) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::TieProductEqualOp) diff --git a/paddle/ir/pass/CMakeLists.txt b/paddle/ir/pass/CMakeLists.txt deleted file mode 100644 index b4a1d99ab5fcd..0000000000000 --- a/paddle/ir/pass/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -file(GLOB NEW_PASS_SRCS "*.cc") - -ir_library(ir_pass SRCS ${NEW_PASS_SRCS} DEPS ir_core) diff --git a/paddle/ir/pattern_rewrite/CMakeLists.txt b/paddle/ir/pattern_rewrite/CMakeLists.txt deleted file mode 100644 index e99611a4ca050..0000000000000 --- a/paddle/ir/pattern_rewrite/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -file(GLOB PATTERN_SRCS "*.cc") - -ir_library(ir_pattern_rewrite SRCS ${PATTERN_SRCS} DEPS ir_core) diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index c7501494b1e71..887f6b2fb0d24 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -590,5 +590,28 @@ 
std::vector SetKernelDistOutput( return results; } +std::vector SetKernelDistInplaceOutput( + size_t out_size, std::vector* out) { + std::vector results(out->size(), nullptr); + for (size_t i = 0; i < out->size(); ++i) { + results[i] = + static_cast(out->at(i).impl().get()); + } + return results; +} + +std::vector SetKernelDistInplaceOptionalOutput( + size_t out_size, paddle::optional> out) { + std::vector results; + if (out) { + results = std::vector(out->size(), nullptr); + for (size_t i = 0; i < out->size(); ++i) { + results[i] = + static_cast(out->at(i).impl().get()); + } + } + return results; +} + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index d0281dfc68184..b13688a2ffb49 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -150,5 +150,11 @@ std::vector SetKernelDistOutput( std::vector SetKernelDistOutput( size_t out_size, std::vector* out); +std::vector SetKernelDistInplaceOutput( + size_t out_size, std::vector* out); + +std::vector SetKernelDistInplaceOptionalOutput( + size_t out_size, paddle::optional> out); + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 7515ff917f10e..e2bb35948d537 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -672,7 +672,7 @@ PrepareDataForDistTensor(const std::vector& input, const TransformFlag& transform_flag, bool is_stride_kernel) { std::vector> out; - for (auto x : input) { + for (auto& x : input) { const auto& tensor_in = x.impl(); if (tensor_in) { phi::distributed::DistTensor* dist_tensor = @@ -691,16 +691,16 @@ PrepareDataForDistTensor(const std::vector& input, dense_tensor.meta().is_contiguous()))) { out.push_back( std::static_pointer_cast(tensor_in)); - continue; + } else { + phi::DenseTensor trans_in_tensor = TransformData( + dense_tensor, target_args_def, transform_flag, is_stride_kernel); + // TODO(GhostScreaming): The global meta in DistTensor is not changed, + // but the local meta in DenseTensor maybe changed, such as layout + // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified. + VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor"; + out.push_back(std::make_shared( + trans_in_tensor, dist_tensor->dist_attr())); } - phi::DenseTensor trans_in_tensor = TransformData( - dense_tensor, target_args_def, transform_flag, is_stride_kernel); - // TODO(GhostScreaming): The global meta in DistTensor is not changed, - // but the local meta in DenseTensor maybe changed, such as layout - // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified. 
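The `paddle::optional` overloads of PrepareDataForDistTensor added below all follow one lifting pattern: an absent input yields `paddle::none`, a present one delegates to the non-optional overload. A generic sketch of that pattern (`Lift` is a hypothetical helper, not part of this PR):

```cpp
#include <utility>
#include "paddle/utils/optional.h"  // assumed home of paddle::optional

// Lift a function over paddle::optional: absent in, absent out.
template <typename T, typename F>
auto Lift(const paddle::optional<T> &in, F &&f)
    -> paddle::optional<decltype(f(*in))> {
  if (in) {
    return {std::forward<F>(f)(*in)};  // delegate to the plain overload
  }
  return paddle::none;
}
```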
- VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor"; - out.push_back(std::make_shared( - trans_in_tensor, dist_tensor->dist_attr())); } else { out.push_back(nullptr); } @@ -708,5 +708,29 @@ PrepareDataForDistTensor(const std::vector& input, return out; } +paddle::optional PrepareDataForDistTensor( + const paddle::optional& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel) { + if (input) { + return {*PrepareDataForDistTensor( + *input, target_args_def, transform_flag, is_stride_kernel)}; + } + return paddle::none; +} + +paddle::optional>> +PrepareDataForDistTensor(const paddle::optional>& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel) { + if (input) { + return PrepareDataForDistTensor( + *input, target_args_def, transform_flag, is_stride_kernel); + } + return paddle::none; +} + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 3ac1b94f144ba..1e6cca8bcf5fd 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -198,5 +198,17 @@ PrepareDataForDistTensor(const std::vector& input, const TransformFlag& transform_flag, bool is_stride_kernel); +paddle::optional PrepareDataForDistTensor( + const paddle::optional& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel); + +paddle::optional>> +PrepareDataForDistTensor(const paddle::optional>& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel); + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index ed671ecdfebd6..9c17d51a3b407 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -93,19 +93,23 @@ VECTOR_OUT_CREATION_TEMPLATE = """ auto dist_out = SetKernelDistOutput({}, &api_output); std::vector dense_out(dist_out.size()); - for (size_t i = 0; i < dist_out.size(); i++) {{ + for (size_t i = 0; i < dist_out.size(); ++i) {{ dense_out[i] = const_cast(&dist_out[i]->value()); }} """ MULTI_VECTOR_OUT_CREATION_TEMPLATE = """ auto dist_out_{out_name} = SetKernelDistOutput({size}, {in_name}); std::vector dense_out_{out_name}(dist_out_{out_name}.size()); - for (size_t i = 0; i < dist_out_{out_name}.size(); i++) {{ + for (size_t i = 0; i < dist_out_{out_name}.size(); ++i) {{ dense_out_{out_name}[i] = const_cast(&dist_out_{out_name}[i]->value()); }} """ -# TODO(GhostScreaming): support tuple output later -TUPLE_OUT_CREATION_TEMPLATE = """ +MULTI_VECTOR_INPLACE_AND_OPTIONAL_OUT_CREATION_TEMPLATE = """ + auto dist_out_{out_name} = {out_func}({size}, {in_name}); + std::vector dense_out_{out_name}(dist_out_{out_name}.size()); + for (size_t i = 0; i < dist_out_{out_name}.size(); ++i) {{ + dense_out_{out_name}[i] = dist_out_{out_name}[i] ? const_cast(&dist_out_{out_name}[i]->value()) : nullptr; + }} """ # 3. 
Infer Global Shape @@ -119,12 +123,28 @@ {name}_meta_vec.emplace_back(MakeMetaTensor(*tmp.impl())); }} std::vector {name}_meta_ptr_vec({name}_meta_vec.size()); - for (size_t i=0; i<{name}_meta_ptr_vec.size(); i++) {{ + for (size_t i=0; i < {name}_meta_ptr_vec.size(); ++i) {{ {name}_meta_ptr_vec[i] = &{name}_meta_vec[i]; }} """ -# TODO(GhostScreaming): support optional args later -OPTIONAL_GLOBAL_VECTOR_META_IN_TEMPLATE = """ +OPTIONAL_GLOBAL_SINGLE_META_IN_TEMPLATE = """meta_dist_{}, """ +OPTIONAL_GLOBAL_SINGLE_META_IN_DECL_TEMPLATE = """ + phi::MetaTensor meta_dist_{name} = {name} ? MakeMetaTensor(*(*{name}).impl()) : phi::MetaTensor(); +""" +OPTIONAL_GLOBAL_VECTOR_META_IN_TEMPLATE = """{}_meta_ptr_vec, """ +OPTIONAL_GLOBAL_VECTOR_META_IN_DECL_TEMPLATE = """ + std::vector {name}_meta_vec_tmp; + if ({name}) {{ + for (auto tmp : *{name}) {{ + {name}_meta_vec_tmp.emplace_back(MakeMetaTensor(*tmp.impl())); + }} + }} + std::vector {name}_meta_ptr_vec_tmp({name}_meta_vec_tmp.size()); + for (size_t i = 0; i < {name}_meta_ptr_vec_tmp.size(); ++i) {{ + {name}_meta_ptr_vec_tmp[i] = &{name}_meta_vec_tmp[i]; + }} + paddle::optional> {name}_meta_ptr_vec = + {name} ? paddle::make_optional>({name}_meta_ptr_vec_tmp) : paddle::none; """ SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE = """ phi::MetaTensor meta_{}({});""" @@ -134,7 +154,7 @@ {name}_meta_vec.emplace_back(phi::MetaTensor(tmp)); }} std::vector {name}_meta_ptr_vec({name}.size()); - for (size_t i=0; i<{name}_meta_vec.size(); i++) {{ + for (size_t i = 0; i < {name}_meta_vec.size(); ++i) {{ {name}_meta_ptr_vec[i] = &{name}_meta_vec[i]; }} """ @@ -173,10 +193,31 @@ }} std::vector dense_input_{name}_meta_vec = MakeMetaTensor(dense_input_{name}_vec); std::vector dense_input_{name}_meta_ptr_vec(dense_input_{name}_meta_vec.size()); - for (size_t i=0; i input_{name} = dist_input_{name} ? paddle::make_optional(dist_input_{name}->value()) : paddle::none; +""" +OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE = """ + auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + std::vector dense_input_{name}_vec; + if ({name}) {{ + for (auto tmp : *dist_input_{name}_vec) {{ + dense_input_{name}_vec.emplace_back(&tmp->value()); + }} + }} + paddle::optional> input_{name}(dense_input_{name}_vec); + std::vector dense_input_{name}_meta_vec = MakeMetaTensor(dense_input_{name}_vec); + std::vector dense_input_{name}_meta_ptr_vec_tmp(dense_input_{name}_meta_vec.size()); + for (size_t i = 0; i < dense_input_{name}_meta_ptr_vec_tmp.size(); ++i) {{ + dense_input_{name}_meta_ptr_vec_tmp[i] = &dense_input_{name}_meta_vec[i]; + }} + paddle::optional> dense_input_{name}_meta_ptr_vec = + {name} ? paddle::make_optional>(dense_input_{name}_meta_ptr_vec_tmp) : paddle::none; +""" INFER_META_SINGLE_INPUT_TEMPLATE = """ auto dist_input_{} = {}.impl(); auto input_{} = &(static_cast(dist_input_{}.get())->value()); @@ -191,16 +232,15 @@ # 7. 
Infer Local DenseTensor Meta
 SINGLE_META_IN_TEMPLATE = """MakeMetaTensor(*input_{}), """
-# TODO(GhostScreaming): support optional args later
 VECTOR_META_IN_TEMPLATE = """dense_input_{}_meta_ptr_vec, """
-OPTIONAL_VECTOR_META_IN_TEMPLATE = """
-"""
+OPTIONAL_SINGLE_META_IN_TEMPLATE = """MakeMetaTensor(input_{}), """
+OPTIONAL_VECTOR_META_IN_TEMPLATE = """dense_input_{}_meta_ptr_vec, """
 SINGLE_META_OUT_DECL_TEMPLATE = """
       phi::MetaTensor meta_{}({});"""
 VECTOR_META_OUT_DECL_TEMPLATE = """
       std::vector<phi::MetaTensor> {name}_meta_vec = MakeMetaTensor({name});
       std::vector<phi::MetaTensor*> {name}_meta_ptr_vec({name}_meta_vec.size());
-      for (size_t i=0; i<{name}_meta_vec.size(); i++) {{
+      for (size_t i = 0; i < {name}_meta_vec.size(); ++i) {{
         {name}_meta_ptr_vec[i] = &{name}_meta_vec[i];
       }}
 """
@@ -221,6 +261,22 @@
     auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
     (*kernel_fn)({}, {});
 """
+# TODO(GhostScreaming): Some operators generate shape info at runtime,
+# e.g. bincount. As a result, dist_output's global shape is set incorrectly,
+# because it is generated in the InferMeta function. A temporary solution is
+# to use a black op list and set the DistTensor shape as an extra step.
+SINGLE_SET_DIST_OUT_DIMS = """
+    dist_out->unsafe_set_dims(dense_out->dims());
+"""
+MULTI_SINGLE_SET_DIST_OUT_DIMS = """
+    dist_out_{}->unsafe_set_dims(dense_out_{}->dims());
+"""
+VECTOR_SET_DIST_OUT_DIMS = """
+    for (size_t i = 0; i < dist_out.size(); ++i) {{
+        dist_out[i]->unsafe_set_dims(dense_out[i]->dims());
+    }}
+"""
+
 PREFIX_VECTOR_TENSOR_NAME = "dense_input_"
 SUFFIX_VECTOR_TENSOR_NAME = "_vec"

@@ -236,13 +292,15 @@
 #     types : [], list of output types
 #     out_size_expr : [], expression for getting size of vector

-# TODO(GhostScreaming): Support std::tuple<...> type of input and output later.
-skip_op_lists = [
-    "check_finite_and_unscale",  # std::vector<Tensor>&, const Tensor& -> std::tuple<std::vector<Tensor>&, Tensor>
-    "coalesce_tensor",  # const std::vector<Tensor>&, DataType, bool, bool, bool, float, bool, int, int, const std::vector<int64_t>&, const std::vector<int64_t>& -> std::tuple<std::vector<Tensor>, Tensor>
-    "update_loss_scaling",  # std::vector<Tensor>, const Tensor, ... -> std::tuple<std::vector<Tensor>, Tensor, Tensor, Tensor>
-    "einsum",
-    "einsum_grad",  # const std::vector<Tensor>&, const std::string& -> std::tuple<std::vector<Tensor>, std::vector<Tensor>>
+
+# TODO(GhostScreaming): Black list of operators which infer shape at runtime.
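+# For an op in this list, the generated kernel-call code appends one of the
+# *_SET_DIST_OUT_DIMS snippets above, so that after the local kernel has run
+# the DistTensor's global shape is refreshed from the local DenseTensor, e.g.
+#     dist_out->unsafe_set_dims(dense_out->dims());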
+ops_infer_shape_in_runtime = [ + "bincount", + "bicubic_interp", + "bilinear_interp", + "linear_interp", + "nearest_interp", + "trilinear_interp", ] @@ -256,12 +314,15 @@ def init_dist_api_members(self): "const Tensor&": { "dense": self.generate_single_dense_input, }, - "const paddle::optional&": { - "dense": self.generate_single_dense_input, - }, "const std::vector&": { "dense": self.generate_vector_dense_input, }, + "const paddle::optional&": { + "dense": self.generate_optional_single_dense_input, + }, + "const paddle::optional>&": { + "dense": self.generate_optional_vector_dense_input, + }, } self.inplace_flag = False @@ -423,25 +484,28 @@ def generate_output_creation_code(self) -> str: get_out_code = f"&std::get<{i}>(api_output)" if self.is_inplace_and_optional_output(i): get_out_code = f"std::get<{i}>(api_output).get_ptr()" - if out_type == 'std::vector': self.vector_output_size_assertion_check() # Special case for inplace vector and inplace optional - # TODO(chenweihang): support this branch later if self.is_inplace_output(i): - set_out_func = "SetInplaceVectorKernelOutput" + set_out_func = "SetKernelDistInplaceOutput" if self.is_inplace_and_optional_output(i): - set_out_func = ( - "SetInplaceOptionalVectorKernelOutput" - ) + set_out_func = "SetKernelDistInplaceOptionalOutput" get_out_code = f"std::get<{i}>(api_output)" - output_creation_code += ( - MULTI_VECTOR_OUT_CREATION_TEMPLATE.format( + output_creation_code += MULTI_VECTOR_INPLACE_AND_OPTIONAL_OUT_CREATION_TEMPLATE.format( + out_func=set_out_func, out_name=i, size=self.outputs['out_size_expr'][i], in_name=get_out_code, ) - ) + else: + output_creation_code += ( + MULTI_VECTOR_OUT_CREATION_TEMPLATE.format( + out_name=i, + size=self.outputs['out_size_expr'][i], + in_name=get_out_code, + ) + ) else: if self.infer_meta['spmd_rule'] is not None: output_creation_code += ( @@ -496,6 +560,31 @@ def generate_infer_global_shape_code(self) -> str: input_meta_code += ( VECTOR_GLOBAL_META_IN_DECL_TEMPLATE.format(name=param) ) + elif ( + self.inputs['input_info'][param] + == "const paddle::optional&" + ): + input_args_code += ( + OPTIONAL_GLOBAL_SINGLE_META_IN_TEMPLATE.format(param) + ) + input_meta_code += ( + OPTIONAL_GLOBAL_SINGLE_META_IN_DECL_TEMPLATE.format( + name=param + ) + ) + elif ( + self.inputs['input_info'][param] + == "const paddle::optional>&" + ): + input_args_code += ( + OPTIONAL_GLOBAL_VECTOR_META_IN_TEMPLATE.format(param) + ) + input_meta_code += ( + OPTIONAL_GLOBAL_VECTOR_META_IN_DECL_TEMPLATE.format( + name=param + ) + ) + else: raise ValueError( f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported." @@ -517,12 +606,7 @@ def generate_infer_global_shape_code(self) -> str: output_decl_code += VECTOR_GLOBAL_META_OUT_DECL_TEMPLATE.format( name=out_name ) - if len(self.dense_output_args) == 1: - output_args_code += f"{out_name}_meta_ptr_vec, " - else: - output_args_code += ( - f"{out_name} ? 
{out_name}_meta_ptr_vec : nullptr, " - ) + output_args_code += f"{out_name}_meta_ptr_vec, " else: output_decl_code += SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE.format( out_name, out_name @@ -628,6 +712,46 @@ def generate_vector_dense_input( return input_tensor_code + def generate_optional_single_dense_input( + self, + input_name, + ): + input_tensor_code = "" + trans_flag = self.gene_trans_flag(input_name) + input_names = self.inputs['names'] + attr_names = self.attrs['names'] + kernel_param = self.kernel['param'] + if kernel_param is None: + kernel_param = input_names + attr_names + + input_tensor_code += OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE.format( + name=input_name, + index=kernel_param.index(input_name), + trans_flag=trans_flag, + ) + + return input_tensor_code + + def generate_optional_vector_dense_input( + self, + input_name, + ): + input_tensor_code = "" + trans_flag = self.gene_trans_flag(input_name) + input_names = self.inputs['names'] + attr_names = self.attrs['names'] + kernel_param = self.kernel['param'] + if kernel_param is None: + kernel_param = input_names + attr_names + + input_tensor_code += OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE.format( + name=input_name, + index=kernel_param.index(input_name), + trans_flag=trans_flag, + ) + + return input_tensor_code + def generate_prepare_data_code(self) -> str: input_names = self.inputs['names'] attr_names = self.attrs['names'] @@ -703,6 +827,20 @@ def generate_infer_meta_code(self) -> str: == "const std::vector&" ): input_args_code += VECTOR_META_IN_TEMPLATE.format(param) + elif ( + self.inputs['input_info'][param] + == "const paddle::optional&" + ): + input_args_code += OPTIONAL_SINGLE_META_IN_TEMPLATE.format( + param + ) + elif ( + self.inputs['input_info'][param] + == "const paddle::optional>&" + ): + input_args_code += OPTIONAL_VECTOR_META_IN_TEMPLATE.format( + param + ) else: raise ValueError( f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported." @@ -724,12 +862,7 @@ def generate_infer_meta_code(self) -> str: output_decl_code += VECTOR_META_OUT_DECL_TEMPLATE.format( name=out_name ) - if len(self.dense_output_args) == 1: - output_args_code += f"{out_name}_meta_ptr_vec, " - else: - output_args_code += ( - f"{out_name} ? 
{out_name}_meta_ptr_vec : nullptr, " - ) + output_args_code += f"{out_name}_meta_ptr_vec, " else: output_decl_code += SINGLE_META_OUT_DECL_TEMPLATE.format( out_name, out_name @@ -818,11 +951,22 @@ def generate_kernel_call_code(self) -> str: kernel_args_type_list.append(dense_output_trans_map[out_type]) kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")" - return KERNEL_CALL_TEMPLATE.format( + result = KERNEL_CALL_TEMPLATE.format( kernel_signature, ", ".join(input_args), ", ".join(self.dense_output_args), ) + global ops_infer_shape_in_runtime + if self.kernel['func'][0] in ops_infer_shape_in_runtime: + if len(self.outputs['types']) == 1: + if self.outputs['types'][0] == 'Tensor': + result += SINGLE_SET_DIST_OUT_DIMS + elif self.outputs['types'][0] == 'std::vector': + result += VECTOR_SET_DIST_OUT_DIMS + else: + for i in range(len(self.outputs['types'])): + result += MULTI_SINGLE_SET_DIST_OUT_DIMS.format(i, i) + return result def generate_return_code(self) -> str: return self.gene_return_code() @@ -845,19 +989,17 @@ def generate_auto_paralel_branch(self) -> str: ) def check_argument_whether_support_auto_parallel(self): - global skip_op_lists for name in self.inputs['names']: if self.inputs['input_info'][name] not in [ "const Tensor&", "const std::vector&", + "const paddle::optional&", + "const paddle::optional>&", ]: return False for out_type in self.outputs['types']: if out_type not in ["Tensor", "std::vector"]: return False - - if self.kernel['func'][0] in skip_op_lists: - return False return True # override BaseAPI's method diff --git a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py index 25944e3356966..8f39859882579 100644 --- a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py @@ -48,6 +48,13 @@ dense_out[i] = const_cast(&dist_out[i]->value()); }} """ +VECTOR_OUT_CREATION_TEMPLATE = """ + auto dist_out = SetKernelDistOutput({name}); + std::vector dense_out(dist_out.size()); + for (size_t i = 0; i < dist_out.size(); i++) {{ + dense_out[i] = const_cast(&dist_out[i]->value()); + }} +""" INPLACE_OUT_CREATION_TEMPLATE = """ *{} = {}; """ @@ -69,6 +76,13 @@ auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); auto input_{arg} = &dist_input_{}->value(); """ +MULTI_VECTOR_OUT_CREATION_TEMPLATE = """ + auto dist_out_{i} = SetKernelDistOutput({name}); + std::vector dense_out_{i}(dist_out_{i}.size()); + for (size_t i = 0; i < dist_out_{i}.size(); i++) {{ + dense_out_{i}[i] = const_cast(&dist_out_{i}[i]->value()); + }} +""" class DistBackwardAPI(DistForwardAPI, BackwardAPI): @@ -104,6 +118,12 @@ def generate_output_creation_code(self) -> str: i, self.outputs['names'][i], i, i ) ) + elif out_type == 'std::vector': + output_creation_code += ( + MULTI_VECTOR_OUT_CREATION_TEMPLATE.format( + i=i, name=self.outputs['names'][i] + ) + ) else: self.vector_output_size_assertion_check() else: diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 4c151374c6893..a647e02b35ef2 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -123,6 +123,25 @@ backward : batch_norm_grad optional : reserve_space +- op : c_allgather + args : (Tensor x, int ring_id, int nranks, bool use_calc_stream) + output : Tensor(out) + infer_meta : + func : AllGatherInferMeta + param: [x, nranks] + kernel : + func : c_allgather + +- op 
: c_allreduce_max + args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel) + output : Tensor(out) + infer_meta : + func : AllReduceInferMeta + param : [x] + kernel : + func : c_allreduce_max + inplace : (x -> out) + - op : c_allreduce_sum args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel) output : Tensor(out) @@ -173,6 +192,16 @@ func : c_identity inplace : (x -> out) +- op : c_reduce_sum + args : (Tensor x, int ring_id, int root_id, bool use_calc_stream) + output : Tensor(out) + infer_meta : + func : DistReduceInferMeta + param : [x] + kernel : + func : c_reduce_sum + inplace : (x -> out) + - op : c_sync_calc_stream args : (Tensor x) output : Tensor(out) @@ -651,19 +680,11 @@ output : Tensor infer_meta : func : MatmulInferMeta - spmd_rule : MatmulSpmdInferForward + spmd_rule : MatmulInferSpmd kernel : func : matmul backward : matmul_grad -- op : matmul_int8 - args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false) - output : Tensor - infer_meta : - func : MatmulInt8InferMeta - kernel : - func : matmul_int8 - - op : matrix_rank args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 495ba53cd7613..9d499c68bef74 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -810,6 +810,7 @@ backward : dropout_grad inputs : x : X + seed_tensor : Seed outputs : out : Out mask : Mask @@ -2428,6 +2429,8 @@ out : Out - op : seed + outputs : + out : Out extra : attrs : [bool deterministic = false, str rng_name = "", bool force_cpu = false] @@ -2638,6 +2641,7 @@ out : Out - op : split + backward : split_grad inputs: x : X outputs: @@ -3047,6 +3051,18 @@ yolo_loss : GetYoloLossExpectedKernelType yolo_loss_grad : GetYoloLossExpectedKernelType +- op: c_allgather + inputs : + x : X + outputs : + out: Out + +- op: c_allreduce_max + inputs : + x : X + outputs : + out: Out + - op: c_allreduce_sum inputs : x : X @@ -3065,6 +3081,12 @@ outputs : out: Out +- op: c_reduce_sum + inputs : + x : X + outputs : + out: Out + - op: c_sync_calc_stream inputs : x : X diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index d95cb12646f41..24ad5087769de 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -30,10 +30,17 @@ namespace phi { void Device::CheckInitialized() { - std::call_once(initialized_, [&]() { this->impl_->InitDevice(dev_id_); }); + std::call_once(initialized_once_flag_, [&]() { + this->impl_->InitDevice(dev_id_); + this->initialized_ = true; + }); } -Device::~Device() { impl_->DeInitDevice(dev_id_); } +Device::~Device() { + if (initialized_) { + impl_->DeInitDevice(dev_id_); + } +} void Device::CreateStream(stream::Stream* stream, const stream::Stream::Priority& priority, diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h index 62c85aeb52674..58a9e6ebe7ab8 100644 --- a/paddle/phi/backends/device_manager.h +++ b/paddle/phi/backends/device_manager.h @@ -127,7 +127,8 @@ class Device final { private: size_t dev_id_; DeviceInterface* impl_; - std::once_flag initialized_; + std::once_flag initialized_once_flag_; + bool initialized_{false}; }; class DeviceManager { diff --git a/paddle/phi/backends/gpu/cuda/cuda_helper.h b/paddle/phi/backends/gpu/cuda/cuda_helper.h index 32a5d10d6291b..555cc2357b2ab 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_helper.h 
+++ b/paddle/phi/backends/gpu/cuda/cuda_helper.h @@ -88,6 +88,12 @@ cudaDataType_t ToCudaDataType() { #if CUDA_VERSION >= 11000 } else if (std::is_same::value) { return CUDA_R_16BF; +#endif +#if CUDA_VERSION >= 11060 + } else if (std::is_same::value) { + return CUDA_R_8I; + } else if (std::is_same::value) { + return CUDA_R_32I; #endif } else { PADDLE_THROW(phi::errors::InvalidArgument( diff --git a/paddle/phi/common/complex.h b/paddle/phi/common/complex.h index e0ff7f11ac542..ceb46874238f3 100644 --- a/paddle/phi/common/complex.h +++ b/paddle/phi/common/complex.h @@ -456,6 +456,26 @@ HOSTDEVICE inline complex tan(const complex& a) { #endif } +template +HOSTDEVICE inline complex sinh(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::sinh(thrust::complex(a))); +#else + return complex(std::sinh(std::complex(a))); +#endif +} + +template +HOSTDEVICE inline complex cosh(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::cosh(thrust::complex(a))); +#else + return complex(std::cosh(std::complex(a))); +#endif +} + template HOSTDEVICE inline complex tanh(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ @@ -466,6 +486,66 @@ HOSTDEVICE inline complex tanh(const complex& a) { #endif } +template +HOSTDEVICE inline complex asin(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::asin(thrust::complex(a))); +#else + return complex(std::asin(std::complex(a))); +#endif +} + +template +HOSTDEVICE inline complex acos(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::acos(thrust::complex(a))); +#else + return complex(std::acos(std::complex(a))); +#endif +} + +template +HOSTDEVICE inline complex atan(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::atan(thrust::complex(a))); +#else + return complex(std::atan(std::complex(a))); +#endif +} + +template +HOSTDEVICE inline complex asinh(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::asinh(thrust::complex(a))); +#else + return complex(std::asinh(std::complex(a))); +#endif +} + +template +HOSTDEVICE inline complex acosh(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::acosh(thrust::complex(a))); +#else + return complex(std::acosh(std::complex(a))); +#endif +} + +template +HOSTDEVICE inline complex atanh(const complex& a) { +#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ + (defined(__CUDA_ARCH__) || defined(__HIPCC__)) + return complex(thrust::atanh(thrust::complex(a))); +#else + return complex(std::atanh(std::complex(a))); +#endif +} + template HOSTDEVICE inline complex conj(const complex& a) { #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \ diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index 20b1da1efe39d..89f18920ef1fd 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -232,6 +232,9 @@ void DenseTensor::set_meta(const DenseTensorMeta& meta) { } else { meta_.strides = meta.strides; } +#ifdef PADDLE_WITH_XPU + 
meta_.scale_value = meta.scale_value;
+#endif
 }

 /* @jim19930609: This interface will be further modified until we finalized the
diff --git a/paddle/phi/core/dense_tensor_impl.cc b/paddle/phi/core/dense_tensor_impl.cc
index ed1944ade402b..4595d11a594f7 100644
--- a/paddle/phi/core/dense_tensor_impl.cc
+++ b/paddle/phi/core/dense_tensor_impl.cc
@@ -392,6 +392,9 @@ DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) {
   meta_.offset = src.meta_.offset;
   meta_.use_gpudnn = src.meta_.use_gpudnn;
   meta_.strides = src.meta_.strides;
+#ifdef PADDLE_WITH_XPU
+  meta_.scale_value = src.meta_.scale_value;
+#endif
   storage_properties_ =
       std::move(CopyStorageProperties(src.storage_properties_));
 #ifdef PADDLE_WITH_DNNL
diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc
index fc105915738bb..46e58cc9b373e 100644
--- a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc
+++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -102,6 +102,10 @@ void TensorDistAttr::set_partial_status(const std::vector<int64_t>& dims,
           "Trying to Set dim %d as Partial which is already a Partial dim.",
           dim));
     }
+    if (std::count(dims_mapping_.begin(), dims_mapping_.end(), dim)) {
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "Trying to set dim %d as Partial when it is already used as a "
+          "sharding dim.",
+          dim));
+    }
     partial_status_.emplace(dim, type);
   }
 }
diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
index 830665670e8ca..b9103a00c9d02 100644
--- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
+++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -14,6 +14,7 @@

 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"

+#include "glog/logging.h"
 #include "paddle/phi/backends/context_pool.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
@@ -54,14 +55,11 @@ DistTensor::DistTensor(const phi::DenseTensor& global_value,
 DistTensor::DistTensor(const DDim& dims, const TensorDistAttr& dist_attr)
     : dims_(dims), dist_attr_(dist_attr) {}

-void DistTensor::set_dims(const DDim& dims) {
-  PADDLE_ENFORCE_EQ(
-      this->initialized(),
-      false,
-      phi::errors::Unimplemented(
-          "DistTensor's set_dims method can only be used when the `value` "
-          "is not initialized (generally used in the InferMeta and "
-          "InferSPMD stages)."));
+void DistTensor::unsafe_set_dims(const DDim& dims) {
+  if (this->initialized()) {
+    VLOG(3) << "You are trying to set the global dims of an initialized "
+               "DistTensor; make sure you are aware of where its dims are "
+               "being changed.";
+  }
   dims_ = dims;
 }

diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h
index bc8b98d81a3ff..1289a23b1be8c 100644
--- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h
+++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h
@@ -56,7 +56,7 @@ class DistTensor final

   /// \brief Set the global dims of the dist tensor.
   /// \return void
-  void set_dims(const DDim& dims);
+  void unsafe_set_dims(const DDim& dims);

   /// \brief Returns the dist attr of current dist tensor.
/// \return The TensorDistAttr's const reference
diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
index 531727b3ee8d1..a1895b6dfbd79 100644
--- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
+++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
@@ -18,7 +18,18 @@ namespace phi {
 namespace distributed {

 void InferSpmdContext::EmplaceBackInput(DistMetaTensor input) {
+  int index = static_cast<int>(inputs_.size());
   inputs_.emplace_back(std::move(input));
+  input_range_.emplace_back(std::pair<int, int>(index, index + 1));
+}
+
+void InferSpmdContext::EmplaceBackInputs(
+    paddle::small_vector<DistMetaTensor, phi::kInputSmallVectorSize> inputs) {
+  int index = static_cast<int>(inputs_.size());
+  input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
+  inputs_.insert(inputs_.end(),
+                 std::make_move_iterator(inputs.begin()),
+                 std::make_move_iterator(inputs.end()));
 }

 void InferSpmdContext::EmplaceBackAttr(Attribute attr) {
@@ -63,6 +74,23 @@ const Attribute& InferSpmdContext::AttrAt(size_t idx) const {
   return attrs_.at(idx);
 }

+const std::pair<int, int>& InferSpmdContext::InputRangeAt(size_t idx) const {
+  return input_range_.at(idx);
+}
+
+const std::vector<const DistMetaTensor*> InferSpmdContext::InputsBetween(
+    size_t start, size_t end) const {
+  std::vector<const DistMetaTensor*> result;
+  result.reserve(end - start);
+  for (size_t i = start; i < end; ++i) {
+    auto& in = inputs_.at(i);
+    result.emplace_back(&in);
+    // result.emplace_back(in.initialized() ? &in : nullptr);
+  }
+
+  return result;
+}
+
 SpmdRuleFactory& SpmdRuleFactory::Instance() {
   static SpmdRuleFactory g_spmd_rule_map;
   return g_spmd_rule_map;
diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h
index bccee2bf5981a..3896bfcd6a2fe 100644
--- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h
+++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h
@@ -45,9 +45,15 @@ class InferSpmdContext {
   void EmplaceBackInput(DistMetaTensor input);
   void EmplaceBackAttr(Attribute attr);
+  void EmplaceBackInputs(
+      paddle::small_vector<DistMetaTensor, phi::kInputSmallVectorSize> inputs);

   const DistMetaTensor& InputAt(size_t idx) const;
+  const std::pair<int, int>& InputRangeAt(size_t idx) const;
+  const std::vector<const DistMetaTensor*> InputsBetween(size_t start,
+                                                         size_t end) const;
+
   template <typename AttrType>
   AttrType AttrAt(size_t idx) const;
@@ -59,6 +65,9 @@ class InferSpmdContext {
   // Because the attribute arguments of dygraph do not have `attr name`,
   // so we use vector instead of map
   paddle::small_vector attrs_;
+  // for vector arguments
+  paddle::small_vector<std::pair<int, int>, phi::kInputSmallVectorSize>
+      input_range_;
 };

 using InferSpmdFn = SpmdInfo (*)(const InferSpmdContext&);
@@ -98,6 +107,24 @@ struct InferSpmdFnImpl {
   }
 };

+  // for vector slot
+  template <typename... Tail>
+  struct InferSpmdFnCallHelper<const std::vector<const DistMetaTensor*>&,
+                               Tail...> {
+    template <int in_idx, int attr_idx, typename... PreviousArgs>
+    static SpmdInfo Call(const InferSpmdContext& ctx, PreviousArgs&...
pargs) { + static_assert(attr_idx == 0, + "InferSpmd's Input should appear before Attributes."); + + const std::pair range = ctx.InputRangeAt(in_idx); + std::vector arg = + ctx.InputsBetween(range.first, range.second); + return InferSpmdFnCallHelper::template Call( + ctx, pargs..., arg); + } + }; + #define PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(attr_type) \ template \ struct InferSpmdFnCallHelper { \ diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index a7df7f3203734..e02868d5e2c1b 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -749,9 +749,9 @@ PHI_DEFINE_EXPORTED_int32( * [false]: not set 0D Tensor to 1D Numpy, close the hack * * Now, just set true by default in 2.5 transition time - * which will be removed in future (2.6 or 2.7) . + * which will be removed in future (2.6) . */ -PHI_DEFINE_EXPORTED_bool(set_to_1d, true, "set 0D Tensor to 1D numpy"); +PHI_DEFINE_EXPORTED_bool(set_to_1d, false, "set 0D Tensor to 1D numpy"); /** * Debug related FLAG @@ -1312,7 +1312,7 @@ PHI_DEFINE_EXPORTED_bool(enable_new_ir_in_executor_trace_run, PHI_DEFINE_EXPORTED_bool(new_ir_apply_inplace_pass, true, "Whether to apply inplace pass on lowering " - "::ir::Program to Kernel Dialect"); + "::pir::Program to Kernel Dialect"); PHI_DEFINE_EXPORTED_bool(enable_record_memory, false, "Enable memory recorder"); @@ -1329,6 +1329,12 @@ PHI_DEFINE_EXPORTED_int64(host_trace_level, "RecordEvent will works " "if host_trace_level >= level."); +PHI_DEFINE_EXPORTED_int32( + multiple_of_cupti_buffer_size, + 1, + "Multiple of the CUPTI device buffer size. If the timestamps have " + "been dropped when you are profiling, try increasing this value."); + #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) /** * Communication library related FLAG diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 2e85d521c516f..d58decadfadca 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -14,10 +14,6 @@ #include "paddle/phi/core/kernel_factory.h" -#include -#include -#include - #include "glog/logging.h" #include "paddle/phi/core/enforce.h" #include "paddle/utils/flags.h" @@ -37,10 +33,6 @@ PHI_DEFINE_EXPORTED_bool(use_stride_kernel, true, "Whether to use strdie kernel if op support stride."); -PHI_DEFINE_EXPORTED_string(stride_kernel_blacklist, - "", - "It controls the strided kernel subset do not use."); - PD_DECLARE_int32(low_precision_op_list); PD_DECLARE_bool(enable_api_kernel_fallback); PD_DECLARE_bool(run_kp_kernel); @@ -234,26 +226,14 @@ KernelResult KernelFactory::SelectKernelOrThrowError( phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); if (FLAGS_use_stride_kernel && use_strided_kernel) { - std::regex reg(","); - std::unordered_set elems{ - std::sregex_token_iterator(FLAGS_stride_kernel_blacklist.begin(), - FLAGS_stride_kernel_blacklist.end(), - reg, - -1), - std::sregex_token_iterator()}; - elems.erase(""); - - if (!elems.count(kernel_name)) { - auto stride_kernel_iter = iter->second.find( - {const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN - ? paddle::experimental::Backend::GPU - : const_kernel_key.backend(), - phi::DataLayout::STRIDED, - const_kernel_key.dtype()}); - if (stride_kernel_iter != iter->second.end()) { - VLOG(1) << "use strided kernel, kernel_name = " << kernel_name; - return {stride_kernel_iter->second, false, true}; - } + auto stride_kernel_iter = iter->second.find( + {const_kernel_key.backend() == paddle::experimental::Backend::GPUDNN + ? 
paddle::experimental::Backend::GPU + : const_kernel_key.backend(), + phi::DataLayout::STRIDED, + const_kernel_key.dtype()}); + if (stride_kernel_iter != iter->second.end()) { + return {stride_kernel_iter->second, false, true}; } } diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 610009fdb70fa..9e3c67fa9ad35 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -16,16 +16,13 @@ #include #include -#include #include #include -#include #include "paddle/phi/common/backend.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/get_kerneltype_forvar_utils.h" -#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/type_defs.h" #include "paddle/phi/core/utils/data_type.h" #include "paddle/utils/flat_hash_map.h" diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index 82d750b692e87..a9356dcfc202a 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -14,20 +14,11 @@ #pragma once -#include -#include -#include #include #include -#include #include "paddle/phi/core/custom_kernel.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/phi/core/extended_tensor.h" -#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_utils.h" -#include "paddle/phi/core/macros.h" -#include "paddle/phi/core/type_defs.h" namespace phi { diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index 33af6abc83aa4..715b4f76392d8 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -27,7 +27,6 @@ #include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/string_tensor.h" #include "paddle/phi/core/tensor_array.h" -#include "paddle/phi/core/type_defs.h" namespace phi { diff --git a/paddle/phi/core/meta_tensor.cc b/paddle/phi/core/meta_tensor.cc index 9b9df5c1ff4aa..53cba02ab0765 100644 --- a/paddle/phi/core/meta_tensor.cc +++ b/paddle/phi/core/meta_tensor.cc @@ -16,7 +16,7 @@ limitations under the License. */ #include "glog/logging.h" -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.h" +#include "paddle/fluid/pir/dialect/operator/ir/meta_tensor.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/enforce.h" @@ -87,7 +87,7 @@ void MetaTensor::set_dims(const DDim& dims) { DenseTensorUtils::GetMutableMeta(static_cast(tensor_)) ->dims = dims; } else if (phi::distributed::DistTensor::classof(tensor_)) { - static_cast(tensor_)->set_dims(dims); + static_cast(tensor_)->unsafe_set_dims(dims); } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported setting dims for `%s`.", tensor_->type_info().name())); diff --git a/paddle/phi/core/tensor_meta.cc b/paddle/phi/core/tensor_meta.cc index 59926ed0b8c25..54c5e409aeb5b 100644 --- a/paddle/phi/core/tensor_meta.cc +++ b/paddle/phi/core/tensor_meta.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/phi/core/tensor_meta.h" -#include "paddle/ir/core/enforce.h" +#include "paddle/pir/core/enforce.h" namespace phi { @@ -118,12 +118,20 @@ DDim DenseTensorMeta::calc_strides(const DDim& dims) { } } -DenseTensorMeta::DenseTensorMeta() { use_gpudnn = true; } +DenseTensorMeta::DenseTensorMeta() { + use_gpudnn = true; +#ifdef PADDLE_WITH_XPU + scale_value = -1.0f; +#endif +} DenseTensorMeta::DenseTensorMeta(DataType dtype, const DDim& dims) : dims(dims), dtype(dtype) { strides = calc_strides(dims); use_gpudnn = true; +#ifdef PADDLE_WITH_XPU + scale_value = -1.0f; +#endif } DenseTensorMeta::DenseTensorMeta(DataType dtype, @@ -131,6 +139,9 @@ DenseTensorMeta::DenseTensorMeta(DataType dtype, const DDim& strides) : dims(dims), dtype(dtype), strides(strides) { use_gpudnn = true; +#ifdef PADDLE_WITH_XPU + scale_value = -1.0f; +#endif } DenseTensorMeta::DenseTensorMeta(DataType dtype, @@ -140,6 +151,9 @@ DenseTensorMeta::DenseTensorMeta(DataType dtype, : dims(dims), dtype(dtype), layout(layout), offset(offset) { strides = calc_strides(dims); use_gpudnn = true; +#ifdef PADDLE_WITH_XPU + scale_value = -1.0f; +#endif } DenseTensorMeta::DenseTensorMeta(DataType dtype, @@ -150,6 +164,9 @@ DenseTensorMeta::DenseTensorMeta(DataType dtype, : dims(dims), dtype(dtype), layout(layout), lod(lod), offset(offset) { strides = calc_strides(dims); use_gpudnn = true; +#ifdef PADDLE_WITH_XPU + scale_value = -1.0f; +#endif } DenseTensorMeta::DenseTensorMeta(const DenseTensorMeta& other) { @@ -165,6 +182,9 @@ DenseTensorMeta::DenseTensorMeta(const DenseTensorMeta& other) { } else { strides = other.strides; } +#ifdef PADDLE_WITH_XPU + scale_value = other.scale_value; +#endif } DenseTensorMeta& DenseTensorMeta::operator=(const DenseTensorMeta& other) { @@ -180,6 +200,9 @@ DenseTensorMeta& DenseTensorMeta::operator=(const DenseTensorMeta& other) { } else { strides = other.strides; } +#ifdef PADDLE_WITH_XPU + scale_value = other.scale_value; +#endif return *this; } @@ -197,7 +220,9 @@ DenseTensorMeta& DenseTensorMeta::operator=( // NOLINT } else { strides = std::move(other.strides); } - +#ifdef PADDLE_WITH_XPU + scale_value = other.scale_value; +#endif return *this; } diff --git a/paddle/phi/core/tensor_meta.h b/paddle/phi/core/tensor_meta.h index ecd746e10037f..2575b51e49fe8 100644 --- a/paddle/phi/core/tensor_meta.h +++ b/paddle/phi/core/tensor_meta.h @@ -82,13 +82,23 @@ struct DenseTensorMeta { LoD lod; size_t offset{0}; DDim strides; + +#ifdef PADDLE_WITH_XPU + // for per tensor scale + float scale_value{-1.0f}; +#endif }; inline bool operator==(const DenseTensorMeta& lhs, const DenseTensorMeta& rhs) { return (lhs.is_scalar == rhs.is_scalar) && lhs.use_gpudnn == rhs.use_gpudnn && (lhs.dims == rhs.dims) && (lhs.dtype == rhs.dtype) && (lhs.layout == rhs.layout) && (lhs.lod == rhs.lod) && +#ifdef PADDLE_WITH_XPU + (lhs.offset == rhs.offset) && (lhs.strides == rhs.strides) && + (lhs.scale_value == rhs.scale_value); +#else (lhs.offset == rhs.offset) && (lhs.strides == rhs.strides); +#endif } struct StringTensorMeta { diff --git a/paddle/phi/core/utils/type_info.cc b/paddle/phi/core/utils/type_info.cc index 99b134b6e7960..2cb903fde7310 100644 --- a/paddle/phi/core/utils/type_info.cc +++ b/paddle/phi/core/utils/type_info.cc @@ -14,7 +14,7 @@ limitations under the License. 
*/ #include -#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_meta_tensor.h" +#include "paddle/fluid/pir/dialect/operator/ir/meta_tensor.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/custom/custom_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index a66790d0ce6cd..2fd87760378fc 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2169,77 +2169,11 @@ void MatmulInferMeta(const MetaTensor& x, auto ddim_out = phi::make_ddim(new_dims); out->set_dims(ddim_out); - out->set_dtype(x.dtype()); - out->set_layout(x.layout()); -} - -void MatmulInt8InferMeta(const MetaTensor& x, - const MetaTensor& y, - bool trans_x, - bool trans_y, - MetaTensor* out) { - std::vector dims_x = phi::vectorize(x.dims()); - std::vector dims_y = phi::vectorize(y.dims()); - auto ndims_x = dims_x.size(); - auto ndims_y = dims_y.size(); - PADDLE_ENFORCE_GT(ndims_x, - 0UL, - phi::errors::InvalidArgument( - "The Input(x) dims size must be greater than 0," - " but reviced dims size is 0. ")); - PADDLE_ENFORCE_GT(ndims_y, - 0UL, - phi::errors::InvalidArgument( - "The Input(y) dims size must be greater than 0," - " but reviced dims size is 0. ")); - - bool x_broadcasted = false, y_broadcasted = false; - if (ndims_x == 1) { - dims_x.insert(dims_x.begin(), 1); - ndims_x = 2; - x_broadcasted = true; - } - - if (ndims_y == 1) { - dims_y.push_back(1); - ndims_y = 2; - y_broadcasted = true; - } - - size_t M, N; - if (trans_x) { - M = dims_x[ndims_x - 1]; - } else { - M = dims_x[ndims_x - 2]; - } - if (trans_y) { - N = dims_y[ndims_y - 2]; - } else { - N = dims_y[ndims_y - 1]; - } - - std::vector new_dims; - if (ndims_x > ndims_y) { - new_dims.assign(dims_x.begin(), dims_x.end() - 2); - } else if (ndims_x < ndims_y) { - new_dims.assign(dims_y.begin(), dims_y.end() - 2); + if (x.dtype() == phi::DataType::INT8) { + out->set_dtype(phi::DataType::INT32); } else { - new_dims.reserve(ndims_x); - for (size_t i = 0; i < ndims_x - 2; ++i) { - new_dims.push_back(std::max(dims_x[i], dims_y[i])); - } - } - if (!x_broadcasted) { - new_dims.push_back(M); // NOLINT - } - if (!y_broadcasted) { - new_dims.push_back(N); // NOLINT + out->set_dtype(x.dtype()); } - - auto ddim_out = phi::make_ddim(new_dims); - - out->set_dims(ddim_out); - out->set_dtype(phi::DataType::INT32); out->set_layout(x.layout()); } @@ -2314,7 +2248,11 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x, } out->set_dims(phi::make_ddim(output_dims)); - out->set_dtype(x.dtype()); + if (x.dtype() == phi::DataType::INT8) { + out->set_dtype(phi::DataType::INT32); + } else { + out->set_dtype(x.dtype()); + } out->share_lod(x); } diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 887da467e07b1..94d8bb606ea5d 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -347,12 +347,6 @@ void MatmulInferMeta(const MetaTensor& x, bool trans_y, MetaTensor* out); -void MatmulInt8InferMeta(const MetaTensor& x, - const MetaTensor& y, - bool trans_x, - bool trans_y, - MetaTensor* out); - void MatmulWithFlattenInferMeta(const MetaTensor& x, const MetaTensor& y, int x_num_col_dims, diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index d5da3a2f8bc87..1c57e2fae92ac 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -223,6 +223,11 @@ void RecvV2InferMeta(const int ring_id, out->set_dtype(dtype); } +void SeedInferMeta(int seed, 
MetaTensor* out) { + out->set_dims(phi::make_ddim({1})); + out->set_dtype(DataType::INT32); +} + void TruncatedGaussianRandomInferMeta(const std::vector& shape, float mean, float std, diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index bc73942c8ec1c..2f9c9a69a13f1 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -83,6 +83,8 @@ void RecvV2InferMeta(const int ring_id, DataType dtype, MetaTensor* out); +void SeedInferMeta(int seed, MetaTensor* out); + void TruncatedGaussianRandomInferMeta(const std::vector& shape, float mean, float std, diff --git a/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc b/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc new file mode 100644 index 0000000000000..4359534dea939 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc @@ -0,0 +1,164 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +////////////////// Utils Functions ////////////////// +std::vector GetDefaultDataParallelDimsmapping( + const int64_t batch_axis_dim, const int ndim) { + std::vector dims_mapping(ndim, -1); + dims_mapping[0] = batch_axis_dim; + return dims_mapping; +} + +////////////////// InferMeta(Contains SPMD) Functions ////////////////// + +SpmdInfo DefaultDataParallelSpmdInferForward( + const std::vector& ins, + const std::vector& outs) { + // step1: Build Einsum Notation for input tensor's batch axis + int64_t ninputs = ins.size(); + int64_t noutputs = outs.size(); + std::vector>> axes_sharding_info; + std::string batch_axis = "b"; + + for (int64_t i = 0; i < ninputs; ++i) { + axes_sharding_info.push_back( + {batch_axis, {ins[i]->dist_attr().dims_mapping()[0]}}); + } + + // Step2: Sharding Merge + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors(axes_sharding_info); + int64_t batch_axis_dim = axis_to_dim_map[batch_axis]; + + // Step3: Infer Output's Batch Axis Dims Mapping. + std::vector output_dist_attrs; + for (int64_t i = 0; i < noutputs; i++) { + int ndim = outs[i]->dims().size(); + TensorDistAttr dist_attr_dst = + CopyTensorDistAttrForOutput(ins[0]->dist_attr()); + std::vector dst_dims_maping = + GetDefaultDataParallelDimsmapping(batch_axis_dim, ndim); + dist_attr_dst.set_dims_mapping(dst_dims_maping); + output_dist_attrs.emplace_back(dist_attr_dst); + } + + // Step4: Merge and get Inputs' Batch Axis New Dims Mapping. 
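+  // e.g. if the merged batch_axis_dim is 0 and an input tensor is 3-D, its
+  // dst dims_mapping becomes [0, -1, -1]: the batch axis is sharded on mesh
+  // dim 0 and every other axis is replicated.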
+ std::vector dst_input_dist_attrs; + for (int64_t i = 0; i < ninputs; i++) { + int ndim = ins[i]->dims().size(); + TensorDistAttr dist_attr_dst = + CopyTensorDistAttrForOutput(ins[i]->dist_attr()); + std::vector dst_dims_maping = + GetDefaultDataParallelDimsmapping(batch_axis_dim, ndim); + dist_attr_dst.set_dims_mapping(dst_dims_maping); + dst_input_dist_attrs.emplace_back(dist_attr_dst); + } + + VLOG(4) << "DefaultDataParallelSpmd InferForward:"; + for (int64_t i = 0; i < ninputs; i++) { + VLOG(4) << "Input" << std::to_string(i) << " shape: [" + << str_join(phi::vectorize(ins[i]->dims())) << "] " + << "src_dims_mapping: [" + << str_join(ins[i]->dist_attr().dims_mapping()) << "] " + << "dst_dims_mapping: [" + << str_join(dst_input_dist_attrs[i].dims_mapping()) << "]"; + } + + for (int64_t i = 0; i < noutputs; i++) { + VLOG(4) << "Output" << std::to_string(i) << " shape: [" + << str_join(phi::vectorize(outs[i]->dims())) << "] " + << "dst_dims_mapping: [" + << str_join(output_dist_attrs[i].dims_mapping()) << "]"; + } + + return {dst_input_dist_attrs, output_dist_attrs}; +} +SpmdInfo DefaultDataParallelSpmdInferBackward( + const std::vector& ins, + const std::vector& outs) { + // step1: Build Einsum Notation for input tensor's batch axis + int64_t ninputs = ins.size(); + int64_t noutputs = outs.size(); + std::vector>> axes_sharding_info; + std::string batch_axis = "b"; + + for (int64_t i = 0; i < noutputs; ++i) { + axes_sharding_info.push_back( + {batch_axis, {outs[i]->dist_attr().dims_mapping()[0]}}); + } + + // Step2: Sharding Merge + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors(axes_sharding_info); + int64_t batch_axis_dim = axis_to_dim_map[batch_axis]; + + // Step3: Infer Output's Batch Axis Dims Mapping. + std::vector output_dist_attrs; + for (int64_t i = 0; i < noutputs; i++) { + int ndim = outs[i]->dims().size(); + TensorDistAttr dist_attr_dst = + CopyTensorDistAttrForOutput(outs[i]->dist_attr()); + std::vector dst_dims_maping = + GetDefaultDataParallelDimsmapping(batch_axis_dim, ndim); + dist_attr_dst.set_dims_mapping(dst_dims_maping); + output_dist_attrs.emplace_back(dist_attr_dst); + } + + // Step4: Merge and get Inputs' Batch Axis New Dims Mapping. 
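+  // Note: batch_axis_dim here was merged from the outputs' dims_mapping[0]
+  // in Step1 above, so the inputs are re-sharded to match the sharding that
+  // the outputs already carry.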
+  std::vector<TensorDistAttr> dst_input_dist_attrs;
+  for (int64_t i = 0; i < ninputs; i++) {
+    int ndim = ins[i]->dims().size();
+    TensorDistAttr dist_attr_dst =
+        CopyTensorDistAttrForOutput(ins[i]->dist_attr());
+    std::vector<int64_t> dst_dims_maping =
+        GetDefaultDataParallelDimsmapping(batch_axis_dim, ndim);
+    dist_attr_dst.set_dims_mapping(dst_dims_maping);
+    dst_input_dist_attrs.emplace_back(dist_attr_dst);
+  }
+
+  VLOG(4) << "DefaultDataParallelSpmd InferBackward:";
+  for (int64_t i = 0; i < noutputs; i++) {
+    VLOG(4) << "Output" << std::to_string(i) << " shape: ["
+            << str_join(phi::vectorize(outs[i]->dims())) << "] "
+            << "src_dims_mapping: ["
+            << str_join(outs[i]->dist_attr().dims_mapping()) << "] "
+            << "dst_dims_mapping: ["
+            << str_join(output_dist_attrs[i].dims_mapping()) << "]";
+  }
+
+  for (int64_t i = 0; i < ninputs; i++) {
+    VLOG(4) << "Input" << std::to_string(i) << " shape: ["
+            << str_join(phi::vectorize(ins[i]->dims())) << "] "
+            << "dst_dims_mapping: ["
+            << str_join(dst_input_dist_attrs[i].dims_mapping()) << "]";
+  }
+
+  return {dst_input_dist_attrs, output_dist_attrs};
+}
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/default_data_parallel.h b/paddle/phi/infermeta/spmd_rules/default_data_parallel.h
new file mode 100644
index 0000000000000..25fa3b65e50a0
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/default_data_parallel.h
@@ -0,0 +1,67 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
+#include "paddle/phi/infermeta/spmd_rules/utils.h"
+
+namespace phi {
+namespace distributed {
+/**
+ * A **hack** rule with a strong assumption that the first dimension of
+ * all the input and output tensors is the batch dimension (broadcast
+ * dimension); therefore, if any tensor's first dimension is sharded, the
+ * sharding is propagated to all the other tensors (on their first
+ * dimension). All the other axes of the tensors are set as unsharded (-1).
+ *
+ * This rule is used to support emerging ops for hybrid parallelism quickly;
+ * once there is a specific rule for an op, we should remove that op from
+ * this rule.
+ *
+ * Vectors of input tensors and output tensors are used as arguments (for
+ * both infer-forward & infer-backward) to support any kind of op.
+ */
+SpmdInfo DefaultDataParallelSpmdInferForward(
+    const std::vector<const DistMetaTensor*>& ins,
+    const std::vector<const DistMetaTensor*>& outs);
+
+SpmdInfo DefaultDataParallelSpmdInferBackward(
+    const std::vector<const DistMetaTensor*>& ins,
+    const std::vector<const DistMetaTensor*>& outs);
+
+// For phi api
+template <typename... Args>
+SpmdInfo PhiDefaultDataParallelSpmdInferForward(const Args&... args) {
+  return detail::PhiSpmdVariadicArgumentParser<
+             DefaultDataParallelSpmdInferForward>()
+      .apply(args...)
+      .InferForward();
+}
+
+template <typename... Args>
+SpmdInfo PhiDefaultDataParallelSpmdInferBackward(const Args&...
args) { + return detail::PhiSpmdVariadicArgumentParser< + DefaultDataParallelSpmdInferBackward>() + .apply(args...) + .InferBackward(); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/matmul.cc b/paddle/phi/infermeta/spmd_rules/matmul.cc index 088f9ab16363a..a29f23b88038c 100644 --- a/paddle/phi/infermeta/spmd_rules/matmul.cc +++ b/paddle/phi/infermeta/spmd_rules/matmul.cc @@ -114,10 +114,10 @@ void FillMatmulOperandNotation(const int x_ndim, ////////////////// InferMeta(Contains SPMD) Functions ////////////////// -SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x, - const DistMetaTensor& y, - bool trans_x, - bool trans_y) { +SpmdInfo MatmulInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + bool trans_x, + bool trans_y) { // Step0: verify input args based on matmul logic auto x_shape = phi::vectorize(x.dims()); auto y_shape = phi::vectorize(y.dims()); @@ -221,11 +221,11 @@ SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x, return {{x_dist_attr_dst, y_dist_attr_dst}, {output_dist_attr_dst}}; } -SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x, - const DistMetaTensor& y, - const DistMetaTensor& out, - bool trans_x, - bool trans_y) { +SpmdInfo MatmulInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out, + bool trans_x, + bool trans_y) { auto out_shape = phi::vectorize(out.dims()); int out_ndim = out_shape.size(); diff --git a/paddle/phi/infermeta/spmd_rules/matmul.h b/paddle/phi/infermeta/spmd_rules/matmul.h index 64cfba26a7445..6bb36f4bd3d34 100644 --- a/paddle/phi/infermeta/spmd_rules/matmul.h +++ b/paddle/phi/infermeta/spmd_rules/matmul.h @@ -22,16 +22,16 @@ limitations under the License. */ namespace phi { namespace distributed { -SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x, +SpmdInfo MatmulInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + bool trans_x, + bool trans_y); + +SpmdInfo MatmulInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& y, + const DistMetaTensor& out, bool trans_x, bool trans_y); -SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x, - const DistMetaTensor& y, - const DistMetaTensor& out, - bool trans_x, - bool trans_y); - } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/replicated.cc b/paddle/phi/infermeta/spmd_rules/replicated.cc new file mode 100644 index 0000000000000..55aa9bf61e0e4 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/replicated.cc @@ -0,0 +1,136 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/infermeta/spmd_rules/replicated.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +////////////////// Utils Functions ////////////////// +std::vector GetReplicatedDimsmapping(const int ndim) { + std::vector dims_mapping(ndim, -1); + return dims_mapping; +} + +////////////////// InferMeta(Contains SPMD) Functions ////////////////// +SpmdInfo ReplicatedSpmdInferForward( + const std::vector& ins, + const std::vector& outs) { + // step1: Build Einsum Notation for input tensor's batch axis + int64_t ninputs = ins.size(); + int64_t noutputs = outs.size(); + + // Step2: Unshard Output's Dims Mapping. + std::vector output_dist_attrs; + for (int64_t i = 0; i < noutputs; i++) { + VLOG(4) << outs[i]->dist_attr().to_string(); + VLOG(4) << outs[i]->dims().to_str(); + int ndim = outs[i]->dims().size(); + TensorDistAttr dist_attr_dst = + CopyTensorDistAttrForOutput(ins[0]->dist_attr()); + std::vector dst_dims_maping = GetReplicatedDimsmapping(ndim); + dist_attr_dst.set_dims_mapping(dst_dims_maping); + output_dist_attrs.emplace_back(dist_attr_dst); + } + + // Step3: Merge and get Inputs' Batch Axis New Dims Mapping. + std::vector dst_input_dist_attrs; + for (int64_t i = 0; i < ninputs; i++) { + int ndim = ins[i]->dims().size(); + TensorDistAttr dist_attr_dst = + CopyTensorDistAttrForOutput(ins[i]->dist_attr()); + std::vector dst_dims_maping = GetReplicatedDimsmapping(ndim); + dist_attr_dst.set_dims_mapping(dst_dims_maping); + dst_input_dist_attrs.emplace_back(dist_attr_dst); + } + + VLOG(4) << "ReplicatedSpmd InferForward:"; + for (int64_t i = 0; i < ninputs; i++) { + VLOG(4) << "Input" << std::to_string(i) << " shape: [" + << str_join(phi::vectorize(ins[i]->dims())) << "] " + << "src_dims_mapping: [" + << str_join(ins[i]->dist_attr().dims_mapping()) << "] " + << "dst_dims_mapping: [" + << str_join(dst_input_dist_attrs[i].dims_mapping()) << "]"; + } + + for (int64_t i = 0; i < noutputs; i++) { + VLOG(4) << "Output" << std::to_string(i) << " shape: [" + << str_join(phi::vectorize(outs[i]->dims())) << "] " + << "dst_dims_mapping: [" + << str_join(output_dist_attrs[i].dims_mapping()) << "]"; + } + + return {dst_input_dist_attrs, output_dist_attrs}; +} + +SpmdInfo ReplicatedSpmdInferBackward( + const std::vector& ins, + const std::vector& outs) { + // step1: Build Einsum Notation for input tensor's batch axis + int64_t ninputs = ins.size(); + int64_t noutputs = outs.size(); + + // Step2: Unshard Output's Dims Mapping. + std::vector output_dist_attrs; + for (int64_t i = 0; i < noutputs; i++) { + int ndim = outs[i]->dims().size(); + TensorDistAttr dist_attr_dst = + CopyTensorDistAttrForOutput(outs[i]->dist_attr()); + std::vector dst_dims_maping = GetReplicatedDimsmapping(ndim); + dist_attr_dst.set_dims_mapping(dst_dims_maping); + output_dist_attrs.emplace_back(dist_attr_dst); + } + + // Step3: Merge and get Inputs' Batch Axis New Dims Mapping. 
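+  // For the replicated rule this is trivial: GetReplicatedDimsmapping yields
+  // an all -1 dims_mapping, e.g. [-1, -1] for a 2-D tensor, i.e. the input
+  // is fully replicated across the mesh.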
+  std::vector<TensorDistAttr> dst_input_dist_attrs;
+  for (int64_t i = 0; i < ninputs; i++) {
+    int ndim = ins[i]->dims().size();
+    TensorDistAttr dist_attr_dst =
+        CopyTensorDistAttrForOutput(ins[i]->dist_attr());
+    std::vector<int64_t> dst_dims_maping = GetReplicatedDimsmapping(ndim);
+    dist_attr_dst.set_dims_mapping(dst_dims_maping);
+    dst_input_dist_attrs.emplace_back(dist_attr_dst);
+  }
+
+  VLOG(4) << "ReplicatedSpmd InferBackward:";
+  for (int64_t i = 0; i < noutputs; i++) {
+    VLOG(4) << "Output" << std::to_string(i) << " shape: ["
+            << str_join(phi::vectorize(outs[i]->dims())) << "] "
+            << "src_dims_mapping: ["
+            << str_join(outs[i]->dist_attr().dims_mapping()) << "] "
+            << "dst_dims_mapping: ["
+            << str_join(output_dist_attrs[i].dims_mapping()) << "]";
+  }
+
+  for (int64_t i = 0; i < ninputs; i++) {
+    VLOG(4) << "Input" << std::to_string(i) << " shape: ["
+            << str_join(phi::vectorize(ins[i]->dims())) << "] "
+            << "dst_dims_mapping: ["
+            << str_join(dst_input_dist_attrs[i].dims_mapping()) << "]";
+  }
+
+  return {dst_input_dist_attrs, output_dist_attrs};
+}
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/replicated.h b/paddle/phi/infermeta/spmd_rules/replicated.h
new file mode 100644
index 0000000000000..7b2ea330be3eb
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/replicated.h
@@ -0,0 +1,61 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
+#include "paddle/phi/infermeta/spmd_rules/utils.h"
+
+namespace phi {
+namespace distributed {
+/**
+ * A Bottom Line Rule that enforces the input(s) and output(s) of the Op to
+ * be replicated among the given mesh.
+ *
+ * This rule is used to support any op that has not been assigned a specific
+ * rule in auto parallel; once there is a specific rule for that op, the
+ * replicated rule no longer affects it.
+ *
+ * Vectors of input tensors and output tensors are used as arguments (for
+ * both infer-forward & infer-backward) to support any kind of op.
+ */
+SpmdInfo ReplicatedSpmdInferForward(
+    const std::vector<const DistMetaTensor*>& ins,
+    const std::vector<const DistMetaTensor*>& outs);
+
+SpmdInfo ReplicatedSpmdInferBackward(
+    const std::vector<const DistMetaTensor*>& ins,
+    const std::vector<const DistMetaTensor*>& outs);
+
+// For phi api
+template <typename... Args>
+SpmdInfo PhiReplicatedSpmdInferForward(const Args&... args) {
+  return detail::PhiSpmdVariadicArgumentParser<ReplicatedSpmdInferForward>()
+      .apply(args...)
+      .InferForward();
+}
+
+template <typename... Args>
+SpmdInfo PhiReplicatedSpmdInferBackward(const Args&... args) {
+  return detail::PhiSpmdVariadicArgumentParser<ReplicatedSpmdInferBackward>()
+      .apply(args...)
+      .InferBackward();
+}
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h
index 5ec2f212ec65b..84eb9bd552f17 100644
--- a/paddle/phi/infermeta/spmd_rules/rules.h
+++ b/paddle/phi/infermeta/spmd_rules/rules.h
@@ -16,7 +16,9 @@ limitations under the License.
*/

 #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
+#include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h"
 #include "paddle/phi/infermeta/spmd_rules/matmul.h"
+#include "paddle/phi/infermeta/spmd_rules/replicated.h"

 /**
  * Design Notes:
@@ -40,8 +42,20 @@ namespace distributed {

 // matmul rule
 PD_REGISTER_SPMD_RULE(matmul,
-                      PD_INFER_SPMD(phi::distributed::MatmulSpmdInferForward),
-                      PD_INFER_SPMD(phi::distributed::MatmulSpmdInferBackward));
+                      PD_INFER_SPMD(phi::distributed::MatmulInferSpmd),
+                      PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse));
+
+// default data parallel rule
+PD_REGISTER_SPMD_RULE(
+    unsqueeze,
+    PD_INFER_SPMD(phi::distributed::DefaultDataParallelSpmdInferForward),
+    PD_INFER_SPMD(phi::distributed::DefaultDataParallelSpmdInferBackward));
+
+// replicated rule /* for unittest */
+PD_REGISTER_SPMD_RULE(
+    replicated,
+    PD_INFER_SPMD(phi::distributed::ReplicatedSpmdInferForward),
+    PD_INFER_SPMD(phi::distributed::ReplicatedSpmdInferBackward));

 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc
index 2252de98a78b3..e7a3dac52ac1d 100644
--- a/paddle/phi/infermeta/spmd_rules/utils.cc
+++ b/paddle/phi/infermeta/spmd_rules/utils.cc
@@ -137,6 +137,8 @@ TensorDistAttr CopyTensorDistAttrForOutput(
   new_dist_attr.set_batch_dim(src_dist_attr.batch_dim());
   new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims());
   // new_dist_attr.set_annotated(false); TODO unset field is false by default.
+  new_dist_attr.clean_partial_status();  // in partial stage I, partial is
+                                         // allowed to propagate
   return new_dist_attr;
 }
diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h
index 5e3c3a3d0961c..e35b9cc792583 100644
--- a/paddle/phi/infermeta/spmd_rules/utils.h
+++ b/paddle/phi/infermeta/spmd_rules/utils.h
@@ -19,6 +19,10 @@ limitations under the License. */
 #include
 #include

+#include "paddle/phi/core/attribute.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
+
 namespace phi {
 namespace distributed {
 class TensorDistAttr;
@@ -61,5 +65,74 @@ std::vector<int64_t> ResoluteOutputPartialDimension(
     const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
     const std::string& tensor_axes);

+// Adaptor for variadic arguments
+template <typename Functor>
+struct ArgsIterator {
+  template <typename... Args>
+  inline Functor& apply() {
+    return self();
+  }
+
+  template <typename T, typename... Args>
+  inline Functor& apply(T&& arg, Args&&...
args) { + self()(std::forward(arg)); + if (self().short_circuit()) { + return self(); + } else { + return apply(std::forward(args)...); + } + } + + constexpr bool short_circuit() const { return false; } + + private: + inline Functor& self() { return *static_cast(this); } +}; + +using SpmdFn = SpmdInfo (*)(const std::vector& ins, + const std::vector& outs); + +namespace detail { +template +struct PhiSpmdVariadicArgumentParser + : public ArgsIterator> { + std::vector inputs; + std::vector outputs; + std::vector attrs; + + // deal with inputs + void operator()(const DistMetaTensor& x) { inputs.emplace_back(&x); } + + void operator()(const std::vector& x) { + for (auto t : x) { + inputs.emplace_back(t); + } + } + + template + void operator()(AttrType x) { + attrs.emplace_back(x); + } + + // deal with outputs + void operator()(DistMetaTensor* out) { outputs.emplace_back(out); } + + void operator()(std::vector out) { + for (auto t : out) { + outputs.emplace_back(t); + } + } + + SpmdInfo InferForward() { + return Fn(inputs, outputs); + // return Fn(inputs, outputs, attrs); + } + + SpmdInfo InferBackward() { + return Fn(inputs, outputs); + // return Fn(inputs, outputs, attrs); + } +}; +} // namespace detail } // namespace distributed } // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index cc8df692a1267..cc2594c9a720a 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -83,6 +83,16 @@ if(WITH_CUTLASS) ) endif() + execute_process( + COMMAND + ${CMAKE_COMMAND} -E make_directory + "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen" + COMMAND ${PYTHON_EXECUTABLE} generic_mixed_gemm_kernelLauncher.py + --cuda_arch "${NVCC_ARCH_BIN}" + WORKING_DIRECTORY + "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/cutlass_kernels/fpA_intB_gemm" + ) + file( GLOB cutlass_cu RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" @@ -90,7 +100,9 @@ if(WITH_CUTLASS) "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu" "fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu" - "fusion/cutlass/memory_efficient_attention/autogen_variable/impl/*.cu") + "fusion/cutlass/memory_efficient_attention/autogen_variable/impl/*.cu" + "fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen/*.cu" + "fusion/cutlass/cutlass_kernels/fpA_intB_gemm/*.cu") list(APPEND kernel_cu ${cutlass_cu}) endif() diff --git a/paddle/phi/kernels/autotune/auto_tune_base.h b/paddle/phi/kernels/autotune/auto_tune_base.h index a6a37272840af..438991ef6fd62 100644 --- a/paddle/phi/kernels/autotune/auto_tune_base.h +++ b/paddle/phi/kernels/autotune/auto_tune_base.h @@ -106,7 +106,7 @@ class AutoTuneBase { float min_time = std::numeric_limits::max(); // Time cost test estabulished in default stream. 
- for (int i = 0; i < kernels_.size(); ++i) { + for (size_t i = 0; i < kernels_.size(); ++i) { auto time = RunAndMeasureKernel(ctx, i, args...); if (time < min_time) { min_time = time; diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index b3203332ec7d1..d3cf1cbcb34c1 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -283,14 +283,14 @@ PD_REGISTER_KERNEL( PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cos_grad, CosGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tan_grad, TanGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(acos_grad, AcosGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(asin_grad, AsinGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(atan_grad, AtanGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(sinh_grad, SinhGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(cosh_grad, CoshGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(asinh_grad, AsinhGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(acosh_grad, AcoshGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(acos_grad, AcosGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(asin_grad, AsinGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(atan_grad, AtanGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sinh_grad, SinhGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cosh_grad, CoshGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(asinh_grad, AsinhGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(acosh_grad, AcoshGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(atanh_grad, AtanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tanh_grad, TanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardtanh_grad, HardTanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) @@ -340,7 +340,9 @@ PD_REGISTER_KERNEL(exp_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(expm1_grad, CPU, @@ -348,7 +350,9 @@ PD_REGISTER_KERNEL(expm1_grad, phi::Expm1GradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL( logit_grad, CPU, ALL_LAYOUT, phi::LogitGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index 8a554470dea39..66480018a5273 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -179,14 +179,14 @@ PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {} PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sin, SinKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cos, CosKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tan, TanKernel) -PD_REGISTER_ACTIVATION_KERNEL(acos, AcosKernel) -PD_REGISTER_ACTIVATION_KERNEL(asin, AsinKernel) -PD_REGISTER_ACTIVATION_KERNEL(atan, AtanKernel) -PD_REGISTER_ACTIVATION_KERNEL(sinh, SinhKernel) -PD_REGISTER_ACTIVATION_KERNEL(cosh, CoshKernel) -PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel) -PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel) -PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(acos, AcosKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(asin, AsinKernel) 
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(atan, AtanKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sinh, SinhKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cosh, CoshKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(asinh, AsinhKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(acosh, AcoshKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(atanh, AtanhKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tanh, TanhKernel) PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) @@ -211,7 +211,9 @@ PD_REGISTER_KERNEL(exp, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(expm1, CPU, @@ -221,7 +223,9 @@ PD_REGISTER_KERNEL(expm1, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {} PD_REGISTER_KERNEL( diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc index ef7987975a12e..24b4615daa58c 100644 --- a/paddle/phi/kernels/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/cpu/compare_kernel.cc @@ -30,22 +30,34 @@ inline void CompareKernelImpl(const Context& ctx, const DenseTensor& y, int axis, DenseTensor* out) { - if (!out->IsSharedWith(x)) { - ctx.template Alloc(out); - if (x.dims().size() >= y.dims().size()) { - funcs::ElementwiseCompute( - ctx, x, y, Functor(), out, axis); - } else { - funcs::ElementwiseCompute( - ctx, x, y, InverseFunctor(), out, axis); - } + ctx.template Alloc(out); + if (x.dims().size() >= y.dims().size()) { + funcs::ElementwiseCompute( + ctx, x, y, Functor(), out, axis); } else { - if (x.dims().size() >= y.dims().size()) { - funcs::ElementwiseCompute(ctx, x, y, Functor(), out, axis); - } else { - funcs::ElementwiseCompute( - ctx, x, y, InverseFunctor(), out, axis); - } + funcs::ElementwiseCompute( + ctx, x, y, InverseFunctor(), out, axis); + } +} + +template +inline void InplaceCompareKernelImpl(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + auto x_origin = x; + out->set_type(phi::DataType::BOOL); + ctx.template Alloc(out); + if (x_origin.dims().size() >= y.dims().size()) { + funcs::ElementwiseCompute( + ctx, x_origin, y, Functor(), out, axis); + } else { + funcs::ElementwiseCompute( + ctx, x_origin, y, InverseFunctor(), out, axis); } } @@ -92,19 +104,21 @@ PD_REGISTER_KERNEL(equal_all, kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } -#define PD_REGISTER_COMPARE_KERNEL(name, func) \ - PD_REGISTER_KERNEL(name, \ - CPU, \ - ALL_LAYOUT, \ - phi::func##Kernel, \ - bool, \ - int16_t, \ - int, \ - int64_t, \ - float, \ - double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16) {} +#define PD_REGISTER_COMPARE_KERNEL(name, func) \ + PD_REGISTER_KERNEL(name, \ + CPU, \ + ALL_LAYOUT, \ + phi::func##Kernel, \ + bool, \ + int16_t, \ + int, \ + int64_t, \ + float, \ + double, \ + phi::dtype::float16, \ + phi::dtype::bfloat16) { \ + kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ + } PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual) PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan) diff --git a/paddle/phi/kernels/cpu/fold_grad_kernel.cc b/paddle/phi/kernels/cpu/fold_grad_kernel.cc index 0c3f1dda03e5e..a56b0aa054571 100644 --- a/paddle/phi/kernels/cpu/fold_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/fold_grad_kernel.cc @@ 
-18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/fold_grad_kernel_impl.h" -PD_REGISTER_KERNEL( - fold_grad, CPU, ALL_LAYOUT, phi::FoldGradKernel, float, double) {} +PD_REGISTER_KERNEL(fold_grad, + CPU, + ALL_LAYOUT, + phi::FoldGradKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/fold_kernel.cc b/paddle/phi/kernels/cpu/fold_kernel.cc index e22ac4c771ed9..df6cf5652c992 100644 --- a/paddle/phi/kernels/cpu/fold_kernel.cc +++ b/paddle/phi/kernels/cpu/fold_kernel.cc @@ -18,4 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/fold_kernel_impl.h" -PD_REGISTER_KERNEL(fold, CPU, ALL_LAYOUT, phi::FoldKernel, float, double) {} +PD_REGISTER_KERNEL(fold, + CPU, + ALL_LAYOUT, + phi::FoldKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/logical_kernel.cc b/paddle/phi/kernels/cpu/logical_kernel.cc index 06dff8428533f..ef657a161c4e5 100644 --- a/paddle/phi/kernels/cpu/logical_kernel.cc +++ b/paddle/phi/kernels/cpu/logical_kernel.cc @@ -24,20 +24,40 @@ namespace phi { -#define DEFINE_LOGICAL_BINARY_KERNEL(type) \ - template \ - void Logical##type##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - DenseTensor* out) { \ - funcs::Logical##type##Functor binary_func; \ - if (out->IsSharedWith(x)) { \ - funcs::ElementwiseCompute, T, T>( \ - dev_ctx, x, y, binary_func, out); \ - } else { \ - funcs::ElementwiseCompute, T, bool>( \ - dev_ctx, x, y, binary_func, out); \ - } \ +template +void LogicalKernelImpl(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + Functor binary_func; + funcs::ElementwiseCompute(dev_ctx, x, y, binary_func, out); +} + +template +void InplaceLogicalKernelImpl(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + Functor binary_func; + auto x_origin = x; + out->set_type(phi::DataType::BOOL); + funcs::ElementwiseCompute( + dev_ctx, x_origin, y, binary_func, out); +} + +#define DEFINE_LOGICAL_BINARY_KERNEL(type) \ + template \ + void Logical##type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + if (out->IsSharedWith(x)) { \ + InplaceLogicalKernelImpl>( \ + dev_ctx, x, y, out); \ + } else { \ + LogicalKernelImpl>( \ + dev_ctx, x, y, out); \ + } \ } DEFINE_LOGICAL_BINARY_KERNEL(And) @@ -52,15 +72,18 @@ void LogicalNotKernel(const Context& dev_ctx, funcs::LogicalNotFunctor unary_func; phi::Transform trans; - if (!out->IsSharedWith(x)) { + if (out->IsSharedWith(x)) { + auto x_origin = x; + out->set_type(phi::DataType::BOOL); auto* out_ptr = dev_ctx.template Alloc(out); - trans(dev_ctx, x.data(), x.data() + x.numel(), out_ptr, unary_func); - } else { trans(dev_ctx, - x.data(), - x.data() + x.numel(), - reinterpret_cast(out->data()), + x_origin.data(), + x_origin.data() + x_origin.numel(), + out_ptr, unary_func); + } else { + auto* out_ptr = dev_ctx.template Alloc(out); + trans(dev_ctx, x.data(), x.data() + x.numel(), out_ptr, unary_func); } } @@ -79,7 +102,9 @@ void LogicalNotKernel(const Context& dev_ctx, int8_t, \ phi::dtype::complex, \ phi::dtype::complex, \ - int16_t) {} + int16_t) { \ + kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ + } REGISTER_LOGICAL_CPU_KERNEL(logical_and, And) REGISTER_LOGICAL_CPU_KERNEL(logical_or, Or) diff --git a/paddle/phi/kernels/funcs/activation_functor.h 
b/paddle/phi/kernels/funcs/activation_functor.h index 6295ca14aa3ad..6b77c31d38d4a 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -892,6 +892,22 @@ struct SinhGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct SinhGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = + dout * x.unaryExpr(Cosh>()).unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + // cosh'(x) = sinh(x) template struct CoshGradFunctor : public BaseActivationFunctor { @@ -907,6 +923,22 @@ struct CoshGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct CoshGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = + dout * x.unaryExpr(Sinh>()).unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + template struct Acos { HOSTDEVICE T operator()(const T& val) const { return acos(val); } @@ -944,6 +976,24 @@ struct AcosGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct AcosGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = + -dout * (static_cast>(1) / + (static_cast>(1) - x.square()).sqrt()) + .unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + template struct Asin { HOSTDEVICE T operator()(const T& val) const { return asin(val); } @@ -981,6 +1031,23 @@ struct AsinGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct AsinGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast>(1) / + (static_cast>(1) - x.square()).sqrt()) + .unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + template struct Atan { HOSTDEVICE T operator()(const T& val) const { return atan(val); } @@ -1017,6 +1084,23 @@ struct AtanGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct AtanGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast>(1) / + (static_cast>(1) + x.square())) + .unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + template struct LogitGradFunctor { template @@ -1066,6 +1150,23 @@ struct AcoshGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct AcoshGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = + dout * (static_cast>(1) / + (-static_cast>(1) + x.square()).sqrt()) + .unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; template struct Asinh { HOSTDEVICE T operator()(const T& val) const { return asinh(val); } @@ -1103,6 +1204,23 @@ struct AsinhGradFunctor : public 
BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct AsinhGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast>(1) / + (x.square() + static_cast>(1)).sqrt()) + .unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + template struct Atanh { HOSTDEVICE T operator()(const T& val) const { return atanh(val); } @@ -1139,6 +1257,22 @@ struct AtanhGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; +template +struct AtanhGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast>(1) / + (static_cast>(1) - x.square())) + .unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; // exp functor // exp(x) = e^x template @@ -1167,6 +1301,33 @@ struct ExpGradFunctor : public BaseActivationFunctor { } }; +template +struct ExpGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * out.unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + +template +struct Expm1 {}; + +template +struct Expm1> { + HOSTDEVICE ComplexType operator()(const ComplexType& val) const { + return exp(val) - static_cast>(1); + } +}; + // expm1(x) = e^x - 1 template struct Expm1Functor : public BaseActivationFunctor { @@ -1178,6 +1339,15 @@ struct Expm1Functor : public BaseActivationFunctor { } }; +template +struct Expm1Functor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.unaryExpr(Expm1>()).eval(); + } +}; + template struct Expm1GradFunctor : public BaseActivationFunctor { template { } }; +template +struct Expm1GradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * out.unaryExpr(Conj()) + dout; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; + // relu(x) = max(x, 0) template struct ReluCPUFunctor : public BaseActivationFunctor { @@ -2831,6 +3016,16 @@ struct CudaExpFunctor : public BaseActivationFunctor { } }; +template +struct CudaExpFunctor> + : public BaseActivationFunctor> { + // exp(x) = exp(x) + __device__ __forceinline__ ComplexType operator()( + const ComplexType x) const { + return static_cast>(exp(x)); + } +}; + template struct CudaSeluFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { @@ -2907,6 +3102,20 @@ struct CudaExpGradFunctor : public BaseActivationFunctor { } }; +template +struct CudaExpGradFunctor> + : public BaseActivationFunctor> { + // dx = dout * exp(x) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType out) const { + return static_cast>(dout * conj(out)); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + template struct CudaReciprocalFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -2947,6 +3156,15 @@ struct CudaExpm1Functor : public BaseActivationFunctor { } }; +template +struct CudaExpm1Functor> + : public BaseActivationFunctor> { + __device__ 
__forceinline__ ComplexType operator()( + const ComplexType x) const { + return static_cast>(Expm1>()(x)); + } +}; + template struct CudaExpm1GradFunctor : public BaseActivationFunctor { // dx = dout * out @@ -2959,6 +3177,20 @@ struct CudaExpm1GradFunctor : public BaseActivationFunctor { } }; +template +struct CudaExpm1GradFunctor> + : public BaseActivationFunctor> { + // dx = dout * exp(x) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType out) const { + return static_cast>(dout * conj(out) + dout); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { + return ActBwdOpFwdDeps::kDepOut; + } +}; + template struct CudaSinFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -3062,6 +3294,20 @@ struct CudaAsinGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaAsinGradFunctor> + : public BaseActivationFunctor> { + ComplexType one = static_cast>(1.0f); + + // dx = dout / sqrt(1 - x^2) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return static_cast>(dout / conj(sqrt(one - x * x))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaAcosFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -3089,6 +3335,20 @@ struct CudaAcosGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaAcosGradFunctor> + : public BaseActivationFunctor> { + ComplexType one = static_cast>(1.0f); + + // dx = -dout / sqrt(1 - x^2) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return static_cast>(-dout / conj(sqrt(one - x * x))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaCoshFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -3115,6 +3375,18 @@ struct CudaCoshGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaCoshGradFunctor> + : public BaseActivationFunctor> { + // dx = dout * sinh(x) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return static_cast>(dout * conj(sinh(x))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaSinhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -3141,6 +3413,18 @@ struct CudaSinhGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaSinhGradFunctor> + : public BaseActivationFunctor> { + // dx = dout * cosh(x) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return static_cast>(dout * conj(cosh(x))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaAcoshFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -3167,6 +3451,19 @@ struct CudaAcoshGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return 
ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaAcoshGradFunctor> + : public BaseActivationFunctor> { + ComplexType one = static_cast>(1.0f); + // dx = dout * 1 / sqrt(x^2 - 1) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return static_cast>(dout * conj(one / sqrt(x * x - one))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaAsinhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -3194,6 +3491,20 @@ struct CudaAsinhGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaAsinhGradFunctor> + : public BaseActivationFunctor> { + ComplexType one = static_cast>(1.0f); + + // dx = dout * 1/sqrt(x^2 + 1) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return static_cast>(dout * conj(one / sqrt(x * x + one))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaAtanhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -3310,6 +3621,19 @@ struct CudaAtanhGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaAtanhGradFunctor> + : public BaseActivationFunctor> { + ComplexType one = static_cast>(1.0f); + // dx = dout * 1/(1- x^2) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return static_cast>(dout * conj(one / (one - x * x))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaSqrtFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -3387,6 +3711,20 @@ struct CudaAtanGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaAtanGradFunctor> + : public BaseActivationFunctor> { + ComplexType one = static_cast>(1.0f); + + // dx = dout / (1 + x^2) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return dout / conj(one + x * x); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaTanhFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; diff --git a/paddle/phi/kernels/funcs/aligned_vector.h b/paddle/phi/kernels/funcs/aligned_vector.h index c931b90a92a70..558e7dc999cf8 100644 --- a/paddle/phi/kernels/funcs/aligned_vector.h +++ b/paddle/phi/kernels/funcs/aligned_vector.h @@ -13,15 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/hostdevice.h" + #if defined(__xpu__) #define CHAR_BIT 8 #endif namespace phi { +template +struct NeedVectorized { + static constexpr bool value = sizeof(T) <= sizeof(float); +}; + // Aligned vector generates vectorized load/store on CUDA. 
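+// A short illustration of the sizing logic here (assumed values): with the
+// 128-bit max load width used in GetVectorizedSize below, valid_vec_size for
+// float is 4, and the pointer alignment then picks the widest usable vector:
+//   GetVectorizedSize<float>(p);   // 4 if p is 16-byte aligned
+//   GetVectorizedSize<float>(q);   // 2 if q is only 8-byte aligned
+//   GetVectorizedSize<double>(r);  // 1: NeedVectorized<double>::value is
+//                                  // false, since sizeof(double) > sizeof(float)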
template struct alignas(sizeof(T) * Size) AlignedVector { @@ -53,6 +61,9 @@ HOSTDEVICE inline void Store(const AlignedVector& vec, T* addr) { */ template int GetVectorizedSize(const T* pointer) { + if (!NeedVectorized::value) { + return 1; + } constexpr int max_load_bits = 128; constexpr int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); uint64_t address = reinterpret_cast(pointer); @@ -76,4 +87,28 @@ int GetVectorizedSize(const T* pointer) { } } +static int GetVectorizedSize(const DenseTensor* tensor) { + int element_size = phi::SizeOf(tensor->dtype()); + if (element_size > sizeof(float)) { + return 1; + } + constexpr int max_load_bits = 128; + int valid_vec_size = max_load_bits / CHAR_BIT / element_size; + uint64_t address = reinterpret_cast(tensor->data()); + + // Currently, decide to deal with no more than 4 data once while adopting + // vectorization load/store, if performance test shows that dealing with + // 8 data once in vectorization load/store does get optimized, code below + // can begin with : + // if (address % (element_size * 8) == 0) { + // return std::min(8, valid_vec_size); + if (address % (element_size * 4) == 0) { + return std::min(4, valid_vec_size); + } else if (address % (element_size * 2) == 0) { + return std::min(2, valid_vec_size); + } else { + return 1; + } +} + } // namespace phi diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h index b46608e91b74a..0fca9de54b2ba 100644 --- a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h @@ -182,6 +182,8 @@ template cublasComputeType_t GetCudaComputeType() { if (std::is_same::value) { return CUBLAS_COMPUTE_64F; + } else if (std::is_same::value) { + return CUBLAS_COMPUTE_32I; } else { return CUBLAS_COMPUTE_32F; } @@ -206,6 +208,17 @@ struct MatmulDescriptor { is_cached = obj.is_cached; } + MatmulDescriptor& operator=(const MatmulDescriptor& obj) { + algo = obj.algo; + x_desc = obj.x_desc; + y_desc = obj.y_desc; + op_desc = obj.op_desc; + out_desc = obj.out_desc; + is_cached = obj.is_cached; + + return *this; + } + ~MatmulDescriptor() PADDLE_MAY_THROW { if (!is_cached) { PADDLE_WARN_GPU_SUCCESS(dynload::cublasLtMatmulDescDestroy(op_desc)); @@ -237,9 +250,15 @@ struct MatmulDescriptor { bool grad_for_dx = true) { using MT = typename phi::dtype::MPTypeTrait::Type; cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); + cudaDataType_t out_mat_type = phi::backends::gpu::ToCudaDataType(); cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); cublasComputeType_t compute_type = GetCudaComputeType(); + if (std::is_same::value) { + out_mat_type = phi::backends::gpu::ToCudaDataType(); + scale_type = phi::backends::gpu::ToCudaDataType(); + } + // Create operation descriptor; see cublasLtMatmulDescAttributes_t for // details about defaults; just need to set the transforms for A and B PADDLE_ENFORCE_GPU_SUCCESS( @@ -249,7 +268,7 @@ struct MatmulDescriptor { // Create matrix descriptors CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x); CreateMatrixLayout(&y_desc, mat_type, K, N, trans_y); - CreateMatrixLayout(&out_desc, mat_type, M, N, false); + CreateMatrixLayout(&out_desc, out_mat_type, M, N, false); // Config batch size and stride. 
 if (batch_size > 1) {
@@ -625,6 +644,197 @@ struct CublasLtBase {
   }
 };
 
+template <>
+struct CublasLtBase<int8_t, int32_t> {
+ public:
+  static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx,
+                                                    size_t workspace_size) {
+    return phi::memory_utils::Alloc(
+        ctx.GetPlace(),
+        workspace_size,
+        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
+  }
+
+  static void RunImpl(const phi::GPUContext& ctx,
+                      MatmulDescriptor* desc,
+                      const size_t sub_key,
+                      const int8_t* x_ptr,
+                      const int8_t* y_ptr,
+                      int32_t* out_ptr,
+                      phi::funcs::MatmulPlanner* planner) {
+    int32_t alpha = 1;
+    int32_t beta =
+        planner->UseAddTo() ? static_cast<int32_t>(1) : static_cast<int32_t>(0);
+    cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle();
+
+    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
+    phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size);
+
+    if (planner != nullptr) {
+      if (phi::autotune::AutoTuneStatus::Instance().UseAutoTune() &&
+          (!desc->is_cached)) {
+        SearchBestAlgo(ctx,
+                       cublaslt_handle,
+                       desc,
+                       static_cast<const void*>(&alpha),
+                       static_cast<const void*>(&beta),
+                       y_ptr,
+                       x_ptr,
+                       out_ptr,
+                       workspace->ptr(),
+                       workspace_size);
+        MatmulDescriptor* best_desc = new MatmulDescriptor(*desc);
+        VLOG(6) << best_desc->GetDescResultString(
+            "[Searched CublasltDescriptor] ");
+
+        auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
+        cache.SetSubKey(sub_key, reinterpret_cast(best_desc));
+      }
+    }
+
+    VLOG(7) << desc->GetDescResultString("[Impl CublasltDescriptor] ");
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmul(cublaslt_handle,
+                                desc->op_desc,
+                                static_cast<const void*>(&alpha),
+                                y_ptr,
+                                desc->y_desc,
+                                x_ptr,
+                                desc->x_desc,
+                                static_cast<const void*>(&beta),
+                                out_ptr,
+                                desc->out_desc,
+                                out_ptr,
+                                desc->out_desc,
+                                desc->algo,
+                                workspace->ptr(),
+                                workspace_size,
+                                ctx.stream()));
+  }
+
+  static void SearchBestAlgo(const phi::GPUContext& ctx,
+                             const cublasLtHandle_t& lt_handle,
+                             MatmulDescriptor* desc,
+                             const void* alpha,
+                             const void* beta,
+                             const void* y_data,
+                             const void* x_data,
+                             void* out_data,
+                             void* workspace_ptr,
+                             size_t workspace_size) {
+    cublasLtMatmulPreference_t preference;
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulPreferenceCreate(&preference));
+    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulPreferenceSetAttribute(
+        preference,
+        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+        &workspace_size,
+        sizeof(workspace_size)));
+
+    int returned_results = 0;
+    constexpr int requested_algo_count = 10;
+    std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results(
+        requested_algo_count);
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle,
+                                                desc->op_desc,
+                                                desc->y_desc,
+                                                desc->x_desc,
+                                                desc->out_desc,
+                                                desc->out_desc,
+                                                preference,
+                                                requested_algo_count,
+                                                heuristic_results.data(),
+                                                &returned_results));
+    PADDLE_ENFORCE_GT(returned_results,
+                      0,
+                      phi::errors::Unavailable("No GEMM algorithm available."));
+    int best_algo_idx = -1;
+    if (returned_results == 1 || FLAGS_cublaslt_exhaustive_search_times <= 0) {
+      best_algo_idx = 0;
+    } else {
+      float min_time_cost = std::numeric_limits<float>::max();
+      for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) {
+        float cur_time_cost =
+            RunAndMeasureAlgo(ctx,
+                              lt_handle,
+                              desc,
+                              alpha,
+                              beta,
+                              y_data,
+                              x_data,
+                              out_data,
+                              workspace_ptr,
+                              workspace_size,
+                              &(heuristic_results[algo_idx].algo));
+        VLOG(6) << "[MatmulWithCublaslt] algo[" << algo_idx
+                << "] time: " << cur_time_cost << " s";
+
+        if ((best_algo_idx == 0 && (1.05 * cur_time_cost < min_time_cost)) ||
+            (cur_time_cost < min_time_cost)) {
+          best_algo_idx = algo_idx;
+          min_time_cost = cur_time_cost;
+        }
+      }
+    }
+    VLOG(6) << "[MatmulWithCublaslt] best_algo_idx: " << best_algo_idx;
+
+    cublasLtMatmulAlgo_t* best_algo = desc->SetAlgo();
+    *best_algo = heuristic_results[best_algo_idx].algo;
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        dynload::cublasLtMatmulPreferenceDestroy(preference));
+  }
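+  // Timing note (illustrative): RunAndMeasureAlgo below runs each candidate
+  // FLAGS_cublaslt_exhaustive_search_times times, excludes the first
+  // (warm-up) run, and reports the average of the remaining runs:
+  //   time_cost = (t_2 + ... + t_repeats) / (repeats - 1)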
+
+  static float RunAndMeasureAlgo(const phi::GPUContext& ctx,
+                                 const cublasLtHandle_t& lt_handle,
+                                 MatmulDescriptor* desc,
+                                 const void* alpha,
+                                 const void* beta,
+                                 const void* y_data,
+                                 const void* x_data,
+                                 void* out_data,
+                                 void* workspace_ptr,
+                                 size_t workspace_size,
+                                 cublasLtMatmulAlgo_t* algo) {
+    int repeats = FLAGS_cublaslt_exhaustive_search_times;
+    if (repeats <= 0) {
+      return std::numeric_limits<float>::max();
+    }
+
+    phi::GpuTimer timer;
+    float time_cost = 0.f;
+    const auto& stream = ctx.stream();
+
+    for (int i = 0; i < repeats; ++i) {
+      timer.Start(stream);
+      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul(lt_handle,
+                                                         desc->op_desc,
+                                                         alpha,
+                                                         y_data,
+                                                         desc->y_desc,
+                                                         x_data,
+                                                         desc->x_desc,
+                                                         beta,
+                                                         out_data,
+                                                         desc->out_desc,
+                                                         out_data,
+                                                         desc->out_desc,
+                                                         algo,
+                                                         workspace_ptr,
+                                                         workspace_size,
+                                                         stream));
+      timer.Stop(stream);
+      ctx.Wait();
+      auto time = timer.ElapsedTime();
+      if (i > 0) {
+        // Exclude the warmup runtime.
+        time_cost += time;
+      }
+    }
+    return (time_cost / (repeats - 1));
+  }
+};
+
 // To judge if desc is cached or not.
 template
+    if (std::is_same<T, int8_t>::value) {
+      if (!trans_x && !trans_y) {
+        PADDLE_ENFORCE_EQ(
+            (N % 4 == 0 || N == 1),
+            true,
+            phi::errors::InvalidArgument(
+                "The dimension size N used in int8 matmul must be 1 or a "
+                "multiple of 4, but it is currently %d.",
+                N));
+        PADDLE_ENFORCE_EQ(
+            (K % 4 == 0),
+            true,
+            phi::errors::InvalidArgument(
+                "The dimension size K used in int8 matmul must be a "
+                "multiple of 4, but it is currently %d.",
+                K));
+      } else if (!trans_x && trans_y) {
+        PADDLE_ENFORCE_EQ(
+            (K % 4 == 0),
+            true,
+            phi::errors::InvalidArgument(
+                "The dimension size K used in int8 matmul must be a "
+                "multiple of 4, but it is currently %d.",
+                K));
+      } else if (trans_x && !trans_y) {
+        PADDLE_ENFORCE_EQ(
+            (M % 4 == 0 || M == 1),
+            true,
+            phi::errors::InvalidArgument(
+                "The dimension size M used in int8 matmul must be 1 or a "
+                "multiple of 4, but it is currently %d.",
+                M));
+        PADDLE_ENFORCE_EQ(
+            (N % 4 == 0 || N == 1),
+            true,
+            phi::errors::InvalidArgument(
+                "The dimension size N used in int8 matmul must be 1 or a "
+                "multiple of 4, but it is currently %d.",
+                N));
+      } else {
+        PADDLE_ENFORCE_EQ(
+            (M % 4 == 0 || M == 1),
+            true,
+            phi::errors::InvalidArgument(
+                "The dimension size M used in int8 matmul must be 1 or a "
+                "multiple of 4, but it is currently %d.",
+                M));
+        PADDLE_ENFORCE_EQ(
+            (K % 4 == 0),
+            true,
+            phi::errors::InvalidArgument(
+                "The dimension size K used in int8 matmul must be a "
+                "multiple of 4, but it is currently %d.",
+                K));
+      }
+    }
+
   if (planner != nullptr) {
     sub_key = planner->GenSubKey();
   }
@@ -680,13 +954,13 @@
 };
 
 // For matmul with kernels autotune
-template <typename T>
-struct MatmulWithCublasLt : public CublasLtBase<T> {
+template <typename T, typename OutT = T>
+struct MatmulWithCublasLt : public CublasLtBase<T, OutT> {
  public:
   static void Run(const phi::GPUContext& ctx,
                   const T*
x_data, const T* y_data, - T* out_data, + OutT* out_data, const int64_t M, const int64_t N, const int64_t K, @@ -695,14 +969,14 @@ struct MatmulWithCublasLt : public CublasLtBase { phi::funcs::MatmulPlanner* planner = nullptr) { auto setter = DescriptorSetter( planner, M, N, K, trans_x, trans_y); - CublasLtBase::RunImpl( + CublasLtBase::RunImpl( ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); } static void RunWithBatch(const phi::GPUContext& ctx, const T* x_data, const T* y_data, - T* out_data, + OutT* out_data, const int64_t M, const int64_t N, const int64_t K, @@ -723,14 +997,14 @@ struct MatmulWithCublasLt : public CublasLtBase { stride_x, stride_y, stride_out); - CublasLtBase::RunImpl( + CublasLtBase::RunImpl( ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); } static void RunWithBatch(const phi::GPUContext& ctx, const T** x_data, const T** y_data, - T** out_data, + OutT** out_data, const int64_t M, const int64_t N, const int64_t K, diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index e754ce3bf49e4..2ba3271d2c7df 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -29,74 +29,86 @@ namespace funcs { #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) -enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 }; - -template -struct UseBroadcast { - template - static HOSTDEVICE void Apply( - const std::vector &ins_tensor, - const ArgsT &args, - int64_t numel, - Array1 *ins_data, - Array2 *use_broadcast, - int *broadcast_num, - bool *all_elementwise) { - (*ins_data)[Index] = (const _ptr_ char *)(ins_tensor[Index]->data()); - bool is_same_dim = ins_tensor[Index]->numel() == numel; - if (is_same_dim) { - (*use_broadcast)[Index] = false; - } else { - (*use_broadcast)[Index] = true; - (*broadcast_num)++; - } - *all_elementwise &= is_same_dim; - } -}; +enum BroadcastType { kMixed = 1, kBroadcast = 2, kElementwise = 3 }; -template -struct LoaderTypeClassifier { - public: +template +struct BroadcastTypeClassifier { int64_t numel{0}; - int vec_size{4}; - int broadcast_num{0}; - bool all_elementwise{true}; - phi::Array use_broadcast; + int broadcast_num{0}; // Not used for XPU + bool all_elementwise{true}; // Not used for XPU + phi::Array use_broadcast; // Not used for XPU + phi::Array configs; phi::Array ins_data; + phi::Array<_ptr_ OutT *, NumOuts> outs_data; + + BroadcastTypeClassifier() {} + BroadcastTypeClassifier(const std::vector &ins, + std::vector *outs, + int axis) { + numel = (*outs)[0]->numel(); + +#ifndef PADDLE_WITH_XPU_KP + for (size_t i = 0; i < ins.size(); ++i) { + bool is_same_dim = ins[i]->numel() == numel; + if (is_same_dim) { + use_broadcast[i] = false; + } else { + use_broadcast[i] = true; + broadcast_num++; + } + all_elementwise &= is_same_dim; + } +#endif + + InitBroadcastConfigs(ins, outs, axis); - LoaderTypeClassifier() {} - LoaderTypeClassifier(const std::vector &ins, - std::vector *outs) { using Traits = phi::funcs::FunctionTraits; using ArgsT = typename Traits::ArgsTuple; ArgsT arg; - uint64_t out_addr = reinterpret_cast((*outs)[0]->data()); - - UnrollerWithoutVecSize::step(ins, arg, &vec_size); - - for (auto i = 1; i < outs->size(); ++i) { - PADDLE_ENFORCE_EQ( - (*outs)[i]->dims(), - (*outs)[0]->dims(), - phi::errors::InvalidArgument( - "The shape of each output tensor shall be identical yet, but " - "%d-th output tensor`s shape is not.", - i)); - out_addr = - (out_addr | 
reinterpret_cast((*outs)[i]->data())); + UnrollerWithoutVecSize::step(ins, arg, &ins_data); + for (int i = 0; i < NumOuts; ++i) { + outs_data[i] = (*outs)[i]->data(); } + } - vec_size = std::min( - vec_size, - phi::GetVectorizedSize(reinterpret_cast(out_addr))); - numel = (*outs)[0]->numel(); - UnrollerWithoutVecSize::step(ins, - arg, - numel, - &ins_data, - &use_broadcast, - &broadcast_num, - &all_elementwise); + void InitBroadcastConfigs(const std::vector &ins, + std::vector *outs, + int axis) { +#ifdef PADDLE_WITH_XPU_KP + const auto dims_simplifier = + BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis); + if (VLOG_IS_ON(6)) { + DimsSimplifiedLogger::Log( + ins, outs, dims_simplifier, "BroadcastKernel"); + } + configs[0] = kps::details::BroadcastConfig(dims_simplifier.out_dims, + dims_simplifier.in_dims[0], + dims_simplifier.in_dims[1], + dims_simplifier.rank); + configs[1] = kps::details::BroadcastConfig(dims_simplifier.out_dims, + dims_simplifier.in_dims[1], + dims_simplifier.in_dims[0], + dims_simplifier.rank); +#else + if (!all_elementwise) { + const auto dims_simplifier = + BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis); + if (VLOG_IS_ON(6)) { + DimsSimplifiedLogger::Log( + ins, outs, dims_simplifier, "BroadcastKernel"); + } + for (int i = 0; i < Arity; ++i) { + // if data shape is[m, n], then you should set data_dim = {n, m} + // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} + // if (ins[i]->numel() != (*outs)[0]->numel()) { + if (ins[i]->numel()) { + configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims, + dims_simplifier.in_dims[i], + dims_simplifier.rank); + } + } + } +#endif } }; @@ -425,18 +437,10 @@ __global__ void VectorizedBroadcastKernel( template void LaunchBroadcastKernel( const KPDevice &ctx, - const std::vector &ins, - std::vector *outs, - Functor func, - const phi::Array &configs, - const LoaderTypeClassifier &loader_classifier) { - phi::Array<_ptr_ OutT *, NumOuts> outs_data; - for (int i = 0; i < NumOuts; ++i) { - outs_data[i] = (_ptr_ OutT *)(ctx.Alloc((*outs)[i])); - } - + const BroadcastTypeClassifier &classifier, + Functor func) { #ifdef PADDLE_WITH_XPU_KP - int numel = (*outs)[0]->numel(); + int numel = classifier.numel; const int threads = 64; const int blocks = 8; int read_lens = configs[0].buf_len; @@ -445,17 +449,17 @@ void LaunchBroadcastKernel( int tail_tid = numel % (read_lens * threads); VectorizedBroadcastKernel - <<>>(loader_classifier.ins_data, - outs_data, - loader_classifier.use_broadcast, + <<>>(classifier.ins_data, + classifier.outs_data, + classifier.use_broadcast, numel, - configs, + classifier.configs, main_offset, tail_tid, read_lens, func); #else - const auto &numel = loader_classifier.numel; + const auto &numel = classifier.numel; auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize); auto stream = ctx.stream(); @@ -464,41 +468,41 @@ void LaunchBroadcastKernel( int main_offset = (numel / (VecSize * threads)) * VecSize * threads; int tail_tid = numel % (VecSize * threads); - if (loader_classifier.all_elementwise) { + if (classifier.all_elementwise) { VectorizedBroadcastKernel - <<>>(loader_classifier.ins_data, - outs_data, - loader_classifier.use_broadcast, + <<>>(classifier.ins_data, + classifier.outs_data, + classifier.use_broadcast, numel, - configs, + classifier.configs, main_offset, tail_tid, VecSize, func); - } else if (loader_classifier.broadcast_num > (Arity >> 1)) { - constexpr BroadcastLoadType type_ = (Arity > 1) ? 
kBroadcast : kMixed;
+  } else if (classifier.broadcast_num > (Arity >> 1)) {
+    constexpr BroadcastType type_ = (Arity > 1) ? kBroadcast : kMixed;
     VectorizedBroadcastKernel
-        <<>>(loader_classifier.ins_data,
-                 outs_data,
-                 loader_classifier.use_broadcast,
+        <<>>(classifier.ins_data,
+                 classifier.outs_data,
+                 classifier.use_broadcast,
                  numel,
-                 configs,
+                 classifier.configs,
                  main_offset,
                  tail_tid,
                  VecSize,
                  func);
   } else {
     VectorizedBroadcastKernel
-        <<>>(loader_classifier.ins_data,
-                 outs_data,
-                 loader_classifier.use_broadcast,
+        <<>>(classifier.ins_data,
+                 classifier.outs_data,
+                 classifier.use_broadcast,
                  numel,
-                 configs,
+                 classifier.configs,
                  main_offset,
                  tail_tid,
                  VecSize,
@@ -632,9 +636,13 @@ struct LaunchBroadcastKernelWithInt64IndexHelper
                                  *outs,
                                  int axis,
                                  Functor functor) {
+    using Traits = phi::funcs::FunctionTraits<Functor>;
+    using ArgsT = typename Traits::ArgsTuple;
+    ArgsT arg;
     phi::Array::kValue> ins_ptrs;
-    UnrollerWithoutVecSize<InputSetter>::step(ins, &ins_ptrs);
+    UnrollerWithoutVecSize<InputSetter>::step(ins, arg, &ins_ptrs);
+
     auto *out_tensor = (*outs)[0];
     auto *out_ptr = ctx.Alloc<OutT>(out_tensor);
@@ -815,26 +823,65 @@ struct LaunchBroadcastKernelWithInt64IndexHelper
-void BroadcastKernelForDifferentVecSize(
-    const KPDevice &ctx,
-    const std::vector<const DenseTensor *> &ins,
-    std::vector<DenseTensor *> *outs,
-    int axis,
-    Functor func) {
+template <typename OutT, typename Functor, int Arity, int NumOuts>
+typename std::enable_if<!NeedVectorized<OutT>::value, void>::type
+BroadcastKernelForDifferentVecSize(const KPDevice &ctx,
+                                   const std::vector<const DenseTensor *> &ins,
+                                   std::vector<DenseTensor *> *outs,
+                                   int axis,
+                                   Functor func) {
 #ifndef PADDLE_WITH_XPU_KP
-  constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && kArity <= 3);
+  constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && Arity <= 3);
   bool use_int64_index_kernel =
       kEnabledInt64IndexKernel &&
       (*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
   if (use_int64_index_kernel) {
-    auto loader_classifier =
-        LoaderTypeClassifier(ins, outs);
-    switch (loader_classifier.vec_size) {
+    LaunchBroadcastKernelWithInt64IndexHelper::Run(ctx,
+                                                   ins,
+                                                   outs,
+                                                   axis,
+                                                   func);
+    return;
+  }
+#endif
+
+  auto classifier =
+      BroadcastTypeClassifier(ins, outs, axis);
+  LaunchBroadcastKernel(
+      ctx, classifier, func);
+}
+
+template <typename OutT, typename Functor, int Arity, int NumOuts>
+typename std::enable_if<NeedVectorized<OutT>::value, void>::type
+BroadcastKernelForDifferentVecSize(const KPDevice &ctx,
+                                   const std::vector<const DenseTensor *> &ins,
+                                   std::vector<DenseTensor *> *outs,
+                                   int axis,
+                                   Functor func) {
+  // The classifier must be constructed before the branch below: the XPU
+  // branch reads classifier.configs when deciding vec_size.
+  auto classifier =
+      BroadcastTypeClassifier(ins, outs, axis);
+
+#ifdef PADDLE_WITH_XPU_KP
+  auto type = kps::details::OptType::CanNotOptimize;
+  bool is_optimize = classifier.configs[0].cmp_type != type;
+  int vec_size = is_optimize ? VecSizeL : VecSizeM;
+#else
+  // Calculate the max vec_size for all ins and outs.
+  int vec_size = GetVectorizedSizeForTensors(ins, *outs);
+#endif
+
+#ifndef PADDLE_WITH_XPU_KP
+  constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && Arity <= 3);
+  bool use_int64_index_kernel =
+      kEnabledInt64IndexKernel &&
+      (*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
+  if (use_int64_index_kernel) {
+    switch (vec_size) {
       case VecSizeL: {
         LaunchBroadcastKernelWithInt64IndexHelper::Run(ctx,
                                                        ins,
@@ -846,7 +893,7 @@
       case VecSizeM: {
         LaunchBroadcastKernelWithInt64IndexHelper::Run(ctx,
                                                        ins,
@@ -858,7 +905,7 @@
       case VecSizeS: {
         LaunchBroadcastKernelWithInt64IndexHelper::Run(ctx,
                                                        ins,
@@ -869,7 +916,7 @@
       }
       default: {
         PADDLE_THROW(phi::errors::Unimplemented(
-            "Unsupported vectorized size: %d!", loader_classifier.vec_size));
+            "Unsupported vectorized size: %d!", vec_size));
         break;
       }
     }
@@ -877,74 +924,27 @@
   }
 #endif
 
-  phi::Array configs;
-#ifdef PADDLE_WITH_XPU_KP
-  PADDLE_ENFORCE_EQ(
-      ins.size(),
-      2,
-      phi::errors::InvalidArgument(
-          "XPU only support inputs is 2, but received %d", ins.size()));
-
-  auto loader_classifier = LoaderTypeClassifier();
-  const auto dims_simplifier =
-      BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
-  if (VLOG_IS_ON(6)) {
-    DimsSimplifiedLogger::Log(
-        ins, outs, dims_simplifier, "XPU Broadcast");
-  }
-  configs[0] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
-                                             dims_simplifier.in_dims[0],
-                                             dims_simplifier.in_dims[1],
-                                             dims_simplifier.rank);
-  configs[1] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
-                                             dims_simplifier.in_dims[1],
-                                             dims_simplifier.in_dims[0],
-                                             dims_simplifier.rank);
-  auto type = kps::details::OptType::CanNotOptimize;
-  bool is_optimize = configs[0].cmp_type != type;
-  int vec_size = is_optimize ? VecSizeL : VecSizeM;
-#else
-  auto loader_classifier =
-      LoaderTypeClassifier(ins, outs);
-  if (!loader_classifier.all_elementwise) {
-    const auto dims_simplifier =
-        BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
-
-    if (VLOG_IS_ON(6)) {
-      DimsSimplifiedLogger::Log(
-          ins, outs, dims_simplifier, "GPU Broadcast");
-    }
-    for (int i = 0; i < kArity; ++i) {
-      // if data shape is[m, n], then you should set data_dim = {n, m}
-      // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
-      // if (ins[i]->numel() != (*outs)[0]->numel()) {
-      if (ins[i]->numel()) {
-        configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
-                                                   dims_simplifier.in_dims[i],
-                                                   dims_simplifier.rank);
-      }
-    }
-  }
-#endif
-  switch (loader_classifier.vec_size) {
+  switch (vec_size) {
     case VecSizeL: {
-      LaunchBroadcastKernel(
-          ctx, ins, outs, func, configs, loader_classifier);
+      LaunchBroadcastKernel(
+          ctx, classifier, func);
       break;
     }
     case VecSizeM: {
-      LaunchBroadcastKernel(
-          ctx, ins, outs, func, configs, loader_classifier);
+      LaunchBroadcastKernel(
+          ctx, classifier, func);
       break;
     }
     case VecSizeS: {
-      LaunchBroadcastKernel(
-          ctx, ins, outs, func, configs, loader_classifier);
+      LaunchBroadcastKernel(
+          ctx, classifier, func);
       break;
     }
     default: {
       PADDLE_THROW(phi::errors::Unimplemented(
-          "Unsupported vectorized size: %d!", loader_classifier.vec_size));
+          "Unsupported vectorized size: %d!", vec_size));
       break;
     }
   }
@@ -960,6 +960,15 @@ void BroadcastKernel(const KPDevice &ctx,
   // maximum rank of all inputs.
using Traits = phi::funcs::FunctionTraits; const int kArity = Traits::arity; + +#ifdef PADDLE_WITH_XPU_KP + PADDLE_ENFORCE_EQ( + ins.size(), + 2, + phi::errors::InvalidArgument( + "XPU only support inputs is 2, but received %d", ins.size())); +#endif + PADDLE_ENFORCE_EQ( ins.size(), kArity, @@ -980,6 +989,19 @@ void BroadcastKernel(const KPDevice &ctx, outs->size(), NumOuts)); + for (auto i = 0; i < outs->size(); ++i) { + if (i > 0) { + PADDLE_ENFORCE_EQ( + (*outs)[i]->dims(), + (*outs)[0]->dims(), + phi::errors::InvalidArgument( + "The shape of each output tensor shall be identical yet, but " + "%d-th output tensor`s shape is not.", + i)); + } + ctx.template Alloc((*outs)[i]); + } + int max_rank = 0; int min_rank = phi::DDim::kMaxRank; for (auto *in : ins) { diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h index 274ac1cc32c05..8ddb3f406ddfe 100644 --- a/paddle/phi/kernels/funcs/elementwise_base.h +++ b/paddle/phi/kernels/funcs/elementwise_base.h @@ -553,43 +553,28 @@ struct Loader { template struct InputSetter { - template - static HOSTDEVICE void Apply( - const std::vector &ins_tensor, Array *ins_data) { - (*ins_data)[Index] = (const _ptr_ char *)(ins_tensor[Index]->data()); - } -}; - -template -struct VecSizeGetter { - template - static HOSTDEVICE void Apply(const std::vector &ins, - const ArgsT &args, - int *vec_size) { + template + static void Apply(const std::vector &ins_tensor, + const ArgsT &args, + Array *ins_data) { using Type = std::tuple_element_t; - *vec_size = std::min(*vec_size, - phi::GetVectorizedSize(ins[Index]->data())); + (*ins_data)[Index] = (const _ptr_ char *)(ins_tensor[Index]->data()); } }; -template -int GetVectorizedSizeForTensors(const std::vector &ins, - const std::vector &outs) { +static int GetVectorizedSizeForTensors( + const std::vector &ins, + const std::vector &outs) { #ifdef PADDLE_WITH_XPU_KP int vec_size = 256; #else - using Traits = phi::funcs::FunctionTraits; - using ArgsT = typename Traits::ArgsTuple; - const int Arity = Traits::arity; int vec_size = 4; - uint64_t addr = static_cast(0); - ArgsT arg; - UnrollerWithoutVecSize::step(ins, arg, &vec_size); - for (auto iter = outs.begin(); iter != outs.end(); ++iter) { - addr = (addr | reinterpret_cast((*iter)->data())); + for (size_t i = 0; i < ins.size(); ++i) { + vec_size = std::min(vec_size, phi::GetVectorizedSize(ins[i])); + } + for (size_t i = 0; i < outs.size(); ++i) { + vec_size = std::min(vec_size, phi::GetVectorizedSize(outs[i])); } - vec_size = std::min( - vec_size, phi::GetVectorizedSize(reinterpret_cast(addr))); #endif return vec_size; } @@ -738,10 +723,10 @@ __global__ void VectorizedElementwiseKernel( } template -void LaunchElementwiseCudaKernel(const KPDevice &ctx, - const std::vector &ins, - std::vector *outs, - Functor func) { +void LaunchElementwiseKernel(const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + Functor func) { // There are at least 1 output, but maybe 0 input (ins.size() == 0). // For large tensor numel * sizeof(T) > 2^31, we must use int64_t as index // type. 
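// (Sketch) the int64 index path matters once numel exceeds INT32_MAX: e.g. a
// tensor of shape [2, 1 << 30] has numel == 2^31, so 32-bit offsets would
// overflow while computing per-thread indices and int64_t must be used.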
@@ -749,10 +734,14 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx, phi::Array ins_data; phi::Array<_ptr_ OutT *, NumOuts> outs_data; - UnrollerWithoutVecSize::step(ins, &ins_data); - for (int i = 0; i < NumOuts; ++i) { - outs_data[i] = (_ptr_ OutT *)(ctx.Alloc((*outs)[i])); + using Traits = phi::funcs::FunctionTraits; + using ArgsT = typename Traits::ArgsTuple; + ArgsT arg; + UnrollerWithoutVecSize::step(ins, arg, &ins_data); + for (int i = 0; i < outs->size(); ++i) { + outs_data[i] = (*outs)[i]->data(); } + #ifdef PADDLE_WITH_XPU_KP int block_size = 64; int grid_size = 8; @@ -775,6 +764,47 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx, #endif } +template +typename std::enable_if::value, void>::type +ElementwiseKernelForDifferentVecSize( + const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + Functor func) { + LaunchElementwiseKernel( + ctx, ins, outs, func); +} + +template +typename std::enable_if::value, void>::type +ElementwiseKernelForDifferentVecSize( + const KPDevice &ctx, + const std::vector &ins, + std::vector *outs, + Functor func) { + // calculate the max vec_size for all ins and outs + int vec_size = GetVectorizedSizeForTensors(ins, *outs); + switch (vec_size) { + case VecSizeL: + LaunchElementwiseKernel( + ctx, ins, outs, func); + break; + case VecSizeM: + LaunchElementwiseKernel( + ctx, ins, outs, func); + break; + case VecSizeS: + LaunchElementwiseKernel( + ctx, ins, outs, func); + break; + default: { + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported vectorized size: %d !", vec_size)); + break; + } + } +} + template void ElementwiseKernel(const KPDevice &ctx, const std::vector &ins, @@ -798,8 +828,8 @@ void ElementwiseKernel(const KPDevice &ctx, outs->size(), NumOuts)); - if (NumOuts > 1) { - for (int i = 1; i < NumOuts; ++i) { + for (int i = 0; i < outs->size(); ++i) { + if (i > 0) { PADDLE_ENFORCE_EQ( (*outs)[i]->dims(), (*outs)[0]->dims(), @@ -808,29 +838,11 @@ void ElementwiseKernel(const KPDevice &ctx, "but %dth output tensor`s shape is not.", i)); } + ctx.template Alloc((*outs)[i]); } - // calculate the max vec_size for all ins and outs - int vec_size = GetVectorizedSizeForTensors(ins, *outs); - switch (vec_size) { - case VecSizeL: - LaunchElementwiseCudaKernel( - ctx, ins, outs, func); - break; - case VecSizeM: - LaunchElementwiseCudaKernel( - ctx, ins, outs, func); - break; - case VecSizeS: - LaunchElementwiseCudaKernel( - ctx, ins, outs, func); - break; - default: { - PADDLE_THROW(phi::errors::Unimplemented( - "Unsupported vectorized size: %d !", vec_size)); - break; - } - } + ElementwiseKernelForDifferentVecSize( + ctx, ins, outs, func); } #endif diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h index b7994d9cefa51..3f221d98b0a25 100644 --- a/paddle/phi/kernels/funcs/elementwise_functor.h +++ b/paddle/phi/kernels/funcs/elementwise_functor.h @@ -40,22 +40,11 @@ struct AddFunctor { inline HOSTDEVICE T operator()(const T a, const T b) const { return a + b; } }; template -struct InverseAddFunctor { - inline HOSTDEVICE T operator()(const T a, const T b) const { return b + a; } -}; - -// Float32Bfloat16Add -template -struct Float32Bfloat16AddFunctor { - inline HOSTDEVICE T operator()(const T x, const phi::bfloat16 y) { - return x + static_cast(y); - } -}; +using InverseAddFunctor = AddFunctor; -// Float32Float16Add -template -struct Float32Float16AddFunctor { - inline HOSTDEVICE T operator()(const T x, const phi::float16 y) { +template +struct MultiPrecisionAddFunctor { + 
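+  // Folds the former Float32Bfloat16AddFunctor / Float32Float16AddFunctor
+  // into one template: y is widened to T (e.g. float) before the add, so
+  // the accumulation happens in the higher-precision type.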
inline HOSTDEVICE T operator()(const T x, const Ty y) const { return x + static_cast(y); } }; @@ -82,15 +71,7 @@ struct MultiplyFunctor { } }; template -struct InverseMultiplyFunctor { - inline HOSTDEVICE T operator()(const T a, const T b) const { return b * a; } -}; -template <> -struct InverseMultiplyFunctor { - inline HOSTDEVICE bool operator()(const bool a, const bool b) const { - return b && a; - } -}; +using InverseMultiplyFunctor = MultiplyFunctor; template struct IsZeroFunctor { diff --git a/paddle/phi/kernels/funcs/im2col.cc b/paddle/phi/kernels/funcs/im2col.cc index 0b5901367488a..e4c470e1a7064 100644 --- a/paddle/phi/kernels/funcs/im2col.cc +++ b/paddle/phi/kernels/funcs/im2col.cc @@ -160,12 +160,24 @@ template class Im2ColFunctor; +template class Im2ColFunctor>; +template class Im2ColFunctor>; template class Col2ImFunctor; template class Col2ImFunctor; +template class Col2ImFunctor>; +template class Col2ImFunctor>; /* * im = [input_channels, input_height, input_width] @@ -331,11 +343,23 @@ template class Im2ColFunctor; +template class Im2ColFunctor>; +template class Im2ColFunctor>; template class Col2ImFunctor; template class Col2ImFunctor; +template class Col2ImFunctor>; +template class Col2ImFunctor>; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/im2col.cu b/paddle/phi/kernels/funcs/im2col.cu index 87c82adbb7fbe..b633241810f9b 100644 --- a/paddle/phi/kernels/funcs/im2col.cu +++ b/paddle/phi/kernels/funcs/im2col.cu @@ -310,6 +310,12 @@ template class Im2ColFunctor; +template class Im2ColFunctor>; +template class Im2ColFunctor>; template class Im2ColFunctor; @@ -322,6 +328,12 @@ template class Col2ImFunctor; +template class Col2ImFunctor>; +template class Col2ImFunctor>; template class Col2ImFunctor; @@ -573,6 +585,12 @@ template class Im2ColFunctor; +template class Im2ColFunctor>; +template class Im2ColFunctor>; template class Im2ColFunctor; @@ -585,6 +603,12 @@ template class Col2ImFunctor; +template class Col2ImFunctor>; +template class Col2ImFunctor>; template class Col2ImFunctor; diff --git a/paddle/phi/kernels/funcs/weight_only_gemv.cu b/paddle/phi/kernels/funcs/weight_only_gemv.cu index a1c746bd49ce1..76716ecf30dc5 100644 --- a/paddle/phi/kernels/funcs/weight_only_gemv.cu +++ b/paddle/phi/kernels/funcs/weight_only_gemv.cu @@ -189,7 +189,12 @@ struct ConvertDstFunc<__nv_bfloat16> { template struct HalfMul { static __device__ __forceinline__ T apply(const T& x, const T& y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 return __hmul(x, y); +#else + float res = static_cast(float16(x)) * static_cast(float16(y)); + return float16(res).to_half(); +#endif } }; diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu new file mode 100644 index 0000000000000..ff98cb01c9866 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.cu @@ -0,0 +1,590 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic pop + +namespace phi { + +template +void dispatch_gemm_config(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy) { + switch (gemm_config.stages) { + case 2: + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case 3: + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case 4: + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + + std::to_string(gemm_config.stages); + throw std::runtime_error("[dispatch_gemm_config] " + err_msg); + break; + } +} + +template +void dispatch_gemm_to_cutlass(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + char* workspace, + size_t workspace_bytes, + CutlassGemmConfig gemm_config, + cudaStream_t stream, + int* occupancy) { + // VLOG(3)<<__PRETTY_FUNCTION__; + // Note that SIMT configs are omitted here since they are not supported for + // fpA_intB. We also only instantiate configs here where threadblockShapeM == + // warpShapeM since those usually perform the best for mixed type gemms. 
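A note on the dispatch pattern used in the switch that follows: the runtime tile_config enum is mapped onto compile-time cutlass::gemm::GemmShape<> template arguments, so each case instantiates a distinct kernel. A minimal, self-contained sketch of the same pattern, with illustrative names rather than the real CUTLASS types:

#include <stdexcept>

enum class TileConfig { Cta32x128x64, Cta64x128x64 };

template <int M, int N, int K>
void run_with_tile() { /* kernel body compiled for one fixed CTA tile */ }

void dispatch(TileConfig cfg) {
  switch (cfg) {
    case TileConfig::Cta32x128x64: run_with_tile<32, 128, 64>(); break;
    case TileConfig::Cta64x128x64: run_with_tile<64, 128, 64>(); break;
    default: throw std::runtime_error("unsupported tile config");
  }
}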
+ switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + // config for M_16000_N_12288_K_6144 in encoder + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case CutlassTileConfig::Undefined: + throw std::runtime_error( + "[fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[fpA_intB][dispatch_gemm_to_cutlass] gemm config should have " + "already been set by heuristic."); + break; + default: + throw std::runtime_error( + "[fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed " + "type GEMM."); + break; + } +} + +template +CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { + // VLOG(3)<<__PRETTY_FUNCTION__; + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + sm_ = getSMVersion(); + check_cuda_error(cudaDeviceGetAttribute( + &multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { + // VLOG(3)<<__PRETTY_FUNCTION__; +} + +template +template +void CutlassFpAIntBGemmRunner::dispatch_to_arch( + const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream, + int* occupancy) { + // VLOG(3)<<__PRETTY_FUNCTION__; + if (sm_ >= 70 && sm_ < 75) { +#if defined(USE_FPAINTB_GEMM_WITH_SM70) + dispatch_gemm_to_cutlass( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); +#else + throw std::runtime_error( + "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " + "CUTLASS mixed type GEMM"); +#endif + } else if (sm_ >= 75 && sm_ < 80) { +#if defined(USE_FPAINTB_GEMM_WITH_SM75) + dispatch_gemm_to_cutlass( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); +#else + throw std::runtime_error( + "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " + "CUTLASS mixed type GEMM"); +#endif + } else if (sm_ >= 80 && sm_ < 90) { +#if defined(USE_FPAINTB_GEMM_WITH_SM80) + dispatch_gemm_to_cutlass( + A, + B, + weight_scales, + biases, + 
C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); +#else + throw std::runtime_error( + "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " + "CUTLASS mixed type GEMM"); +#endif + } else { + throw std::runtime_error( + "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " + "CUTLASS mixed type GEMM"); + } +} + +template +template +void CutlassFpAIntBGemmRunner::run_gemm( + const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + // VLOG(3)<<__PRETTY_FUNCTION__; + static constexpr bool is_weight_only = !std::is_same::value; + const bool is_weight_only_encoder = m >= 512 ? true : false; + std::vector candidate_configs = + get_candidate_configs(sm_, is_weight_only, is_weight_only_encoder, false); + std::vector occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + candidate_configs[ii], + workspace_ptr, + workspace_bytes, + stream, + &occupancies[ii]); + } + // Standard GEMM, so 1 "expert". We use the same function for MoE and regular + // FFN. + static constexpr int num_experts = 1; + CutlassGemmConfig chosen_config = + estimate_best_config_from_occupancies(candidate_configs, + occupancies, + m, + n, + k, + num_experts, + split_k_limit, + workspace_bytes, + multi_processor_count_, + is_weight_only); + + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + chosen_config, + workspace_ptr, + workspace_bytes, + stream); +} + +template +void CutlassFpAIntBGemmRunner::gemm_bias_act( + const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + std::string activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + // VLOG(3)<<__PRETTY_FUNCTION__; + if (activation_type == "gelu") { + run_gemm(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + stream); + } else if (activation_type == "relu") { + run_gemm(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + stream); + } else if (activation_type == "none") { + run_gemm(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + stream); + } else { + throw std::runtime_error(("Invalid activation type.")); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm(const T* A, + const WeightType* B, + const float* weight_scales, + T* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + // VLOG(3)<<__PRETTY_FUNCTION__; + run_gemm(A, + B, + weight_scales, + nullptr, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + stream); +} + +template +int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, + const int n, + const int k) { + // VLOG(3)<<__PRETTY_FUNCTION__; // These are the min tile sizes for each + // config, which would launch the maximum number of blocks + const int max_grid_m = (m + 31) / 32; + const int max_grid_n = (n + 127) / 128; + // We need 4 bytes per block in the worst case. We launch split_k_limit in z + // dim. 
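  // Worked example: for the encoder shape mentioned above (m = 16000,
  // n = 12288), max_grid_m = (16000 + 31) / 32 = 500 and
  // max_grid_n = (12288 + 127) / 128 = 96, so the workspace is
  // 500 * 96 * split_k_limit * 4 bytes. The divisors 32 and 128 correspond
  // to the smallest CTA tile instantiated (32x128).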
+ return max_grid_m * max_grid_n * split_k_limit * 4; +} + +// =============================== Specialization T == WeightType +// ======================================= +template +void CutlassFpAIntBGemmRunner::gemm_bias_act( + const float* A, + const WeightType* B, + const float* weight_scales, + const float* biases, + float* C, + int m, + int n, + int k, + std::string activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + throw std::runtime_error( + ("Attempting to run mixed gemm bias act when the types are the same is " + "an error.")); +} + +template +void CutlassFpAIntBGemmRunner::gemm( + const float* A, + const WeightType* B, + const float* weight_scales, + float* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + throw std::runtime_error(( + "Attempting to run mixed gemm when the types are the same is an error.")); +} + +template +int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, + const int n, + const int k) { + return 0; +} + +template class CutlassFpAIntBGemmRunner; +template class CutlassFpAIntBGemmRunner; +#ifdef PADDLE_CUDA_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t>; +#endif +template class CutlassFpAIntBGemmRunner; +template class CutlassFpAIntBGemmRunner; +#ifdef PADDLE_CUDA_BF16 +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; +#endif +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index e0ce642cef243..f5862a8c58959 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -45,6 +45,7 @@ limitations under the License. 
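The header change below follows a declare-in-header, define-in-generated-source pattern: generic_mixed_gemm_kernelLauncher_template is only declared here, and the Python generator added later in this diff emits one explicit specialization per (dtype, arch, tile, stages, epilogue) combination into autogen/*.cu, so no single translation unit has to instantiate every CUTLASS kernel. A minimal sketch of the pattern, with illustrative names:

// header: declaration only
template <typename T, int Stages>
void launch_gemm(const T* a, const T* b, T* c, int m, int n, int k);

// one generated .cu per configuration: explicit specialization
template <>
void launch_gemm<float, 2>(const float* a, const float* b, float* c,
                           int m, int n, int k) {
  // the real generated code instantiates the CUTLASS kernel for this config
}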
*/ #pragma GCC diagnostic pop #include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen/arch_define.h" #include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h" #include "paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h" namespace phi { @@ -221,6 +222,27 @@ void generic_mixed_gemm_kernelLauncher(const T* A, } } +template +void generic_mixed_gemm_kernelLauncher_template(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy); + template (A, - B, - weight_scales, - biases, - C, - m, - n, - k, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); + generic_mixed_gemm_kernelLauncher_template(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); } }; +#if defined(USE_FPAINTB_GEMM_WITH_SM80) template (A, - B, - weight_scales, - biases, - C, - m, - n, - k, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); + generic_mixed_gemm_kernelLauncher_template(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); } }; +#endif template ; - DispatcherStages2::dispatch(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case 3: - using DispatcherStages3 = dispatch_stages; - DispatcherStages3::dispatch(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case 4: - using DispatcherStages4 = dispatch_stages; - DispatcherStages4::dispatch(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - default: - std::string err_msg = "dispatch_gemm_config does not support stages " + - std::to_string(gemm_config.stages); - throw std::runtime_error("[dispatch_gemm_config] " + err_msg); - break; - } -} + int* occupancy); template void dispatch_gemm_to_cutlass(const T* A, @@ -456,430 +406,6 @@ void dispatch_gemm_to_cutlass(const T* A, size_t workspace_bytes, CutlassGemmConfig gemm_config, cudaStream_t stream, - int* occupancy) { - // VLOG(3)<<__PRETTY_FUNCTION__; - // Note that SIMT configs are omitted here since they are not supported for - // fpA_intB. We also only instantiate configs here where threadblockShapeM == - // warpShapeM since those usually perform the best for mixed type gemms. 
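Worth noting in the hunk above: the Sm80 dispatch_stages specialization is now wrapped in #if defined(USE_FPAINTB_GEMM_WITH_SM80), a macro emitted into autogen/arch_define.h by the generator for each requested SM version, and the runtime dispatch throws for paths that were compiled out. A runnable sketch of the guard pattern (the macro is defined inline here purely for illustration):

#include <stdexcept>

// In the real build this comes from the generated autogen/arch_define.h.
#define USE_FPAINTB_GEMM_WITH_SM80

void run_sm80_path() { /* CUTLASS SM80 kernels */ }

void dispatch_sm(int sm) {
  if (sm >= 80 && sm < 90) {
#if defined(USE_FPAINTB_GEMM_WITH_SM80)
    run_sm80_path();
#else
    throw std::runtime_error("SM80 path was compiled out of this build");
#endif
  } else {
    throw std::runtime_error("arch unsupported for mixed-type GEMM");
  }
}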
- switch (gemm_config.tile_config) { - case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>( - A, - B, - weight_scales, - biases, - C, - m, - n, - k, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>( - A, - B, - weight_scales, - biases, - C, - m, - n, - k, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<128, 32, 64>>( - A, - B, - weight_scales, - biases, - C, - m, - n, - k, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - // config for M_16000_N_12288_K_6144 in encoder - case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 64, 64>>( - A, - B, - weight_scales, - biases, - C, - m, - n, - k, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 64, 64>>( - A, - B, - weight_scales, - biases, - C, - m, - n, - k, - gemm_config, - workspace, - workspace_bytes, - stream, - occupancy); - break; - case CutlassTileConfig::Undefined: - throw std::runtime_error( - "[fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined."); - break; - case CutlassTileConfig::ChooseWithHeuristic: - throw std::runtime_error( - "[fpA_intB][dispatch_gemm_to_cutlass] gemm config should have " - "already been set by heuristic."); - break; - default: - throw std::runtime_error( - "[fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed " - "type GEMM."); - break; - } -} - -template -CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { - // VLOG(3)<<__PRETTY_FUNCTION__; - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - sm_ = getSMVersion(); - check_cuda_error(cudaDeviceGetAttribute( - &multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); -} - -template -CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { - // VLOG(3)<<__PRETTY_FUNCTION__; -} - -template -template -void CutlassFpAIntBGemmRunner::dispatch_to_arch( - const T* A, - const WeightType* B, - const float* weight_scales, - const T* biases, - T* C, - int m, - int n, - int k, - CutlassGemmConfig gemm_config, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream, - int* occupancy) { - // VLOG(3)<<__PRETTY_FUNCTION__; - if (sm_ >= 70 && sm_ < 75) { - dispatch_gemm_to_cutlass( - A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); - } else if (sm_ >= 75 && sm_ < 80) { - dispatch_gemm_to_cutlass( - A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); - } else if (sm_ >= 80 && sm_ < 90) { - dispatch_gemm_to_cutlass( - A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - gemm_config, - stream, - occupancy); - } else { - throw std::runtime_error( - "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for " - "CUTLASS mixed type GEMM"); - } -} + int* occupancy); -template -template -void CutlassFpAIntBGemmRunner::run_gemm( - const T* A, - const WeightType* B, - const float* weight_scales, - const T* biases, 
- T* C, - int m, - int n, - int k, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream) { - // VLOG(3)<<__PRETTY_FUNCTION__; - static constexpr bool is_weight_only = !std::is_same::value; - const bool is_weight_only_encoder = m >= 512 ? true : false; - std::vector candidate_configs = - get_candidate_configs(sm_, is_weight_only, is_weight_only_encoder, false); - std::vector occupancies(candidate_configs.size()); - - for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { - dispatch_to_arch(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - candidate_configs[ii], - workspace_ptr, - workspace_bytes, - stream, - &occupancies[ii]); - } - // Standard GEMM, so 1 "expert". We use the same function for MoE and regular - // FFN. - static constexpr int num_experts = 1; - CutlassGemmConfig chosen_config = - estimate_best_config_from_occupancies(candidate_configs, - occupancies, - m, - n, - k, - num_experts, - split_k_limit, - workspace_bytes, - multi_processor_count_, - is_weight_only); - - dispatch_to_arch(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - chosen_config, - workspace_ptr, - workspace_bytes, - stream); -} - -template -void CutlassFpAIntBGemmRunner::gemm_bias_act( - const T* A, - const WeightType* B, - const float* weight_scales, - const T* biases, - T* C, - int m, - int n, - int k, - std::string activation_type, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream) { - // VLOG(3)<<__PRETTY_FUNCTION__; - if (activation_type == "gelu") { - run_gemm(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - stream); - } else if (activation_type == "relu") { - run_gemm(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - stream); - } else if (activation_type == "none") { - run_gemm(A, - B, - weight_scales, - biases, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - stream); - } else { - throw std::runtime_error(("Invalid activation type.")); - } -} - -template -void CutlassFpAIntBGemmRunner::gemm(const T* A, - const WeightType* B, - const float* weight_scales, - T* C, - int m, - int n, - int k, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream) { - // VLOG(3)<<__PRETTY_FUNCTION__; - run_gemm(A, - B, - weight_scales, - nullptr, - C, - m, - n, - k, - workspace_ptr, - workspace_bytes, - stream); -} - -template -int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, - const int n, - const int k) { - // VLOG(3)<<__PRETTY_FUNCTION__; // These are the min tile sizes for each - // config, which would launch the maximum number of blocks - const int max_grid_m = (m + 31) / 32; - const int max_grid_n = (n + 127) / 128; - // We need 4 bytes per block in the worst case. We launch split_k_limit in z - // dim. 
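The run_gemm body removed here (it now lives in the new .cu) auto-tunes by launching every candidate config once in occupancy-query mode, then letting estimate_best_config_from_occupancies pick a winner. A deliberately simplified sketch of that selection step; the real heuristic also weighs m/n/k, split-k and the SM count:

#include <cstddef>
#include <vector>

struct Config { int tile_config; int stages; };

// Simplified: pick the candidate with the highest measured occupancy.
Config pick_best(const std::vector<Config>& candidates,
                 const std::vector<int>& occupancies) {
  std::size_t best = 0;
  for (std::size_t i = 1; i < occupancies.size(); ++i) {
    if (occupancies[i] > occupancies[best]) best = i;
  }
  return candidates[best];
}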
- return max_grid_m * max_grid_n * split_k_limit * 4; -} - -// =============================== Specialization T == WeightType -// ======================================= -template -void CutlassFpAIntBGemmRunner::gemm_bias_act( - const float* A, - const WeightType* B, - const float* weight_scales, - const float* biases, - float* C, - int m, - int n, - int k, - std::string activation_type, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream) { - throw std::runtime_error( - ("Attempting to run mixed gemm bias act when the types are the same is " - "an error.")); -} - -template -void CutlassFpAIntBGemmRunner::gemm( - const float* A, - const WeightType* B, - const float* weight_scales, - float* C, - int m, - int n, - int k, - char* workspace_ptr, - const size_t workspace_bytes, - cudaStream_t stream) { - throw std::runtime_error(( - "Attempting to run mixed gemm when the types are the same is an error.")); -} - -template -int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, - const int n, - const int k) { - return 0; -} - -template class CutlassFpAIntBGemmRunner; -template class CutlassFpAIntBGemmRunner; -#ifdef PADDLE_CUDA_BF16 -template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t>; -#endif -template class CutlassFpAIntBGemmRunner; -template class CutlassFpAIntBGemmRunner; -#ifdef PADDLE_CUDA_BF16 -template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; -#endif } // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py new file mode 100644 index 0000000000000..4295057679d57 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py @@ -0,0 +1,214 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import re + +# this is a file's header part +CommonHead = ''' +// Generated by generic_mixed_gemm_kernelLauncher.py - Do not edit. 
+ +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace phi { +''' + +CommonTail = ''' +} // namespace phi + +''' +DispatchGemmConfigInstanceDeclare = """ +template<> +void generic_mixed_gemm_kernelLauncher_template<{T}, + {WeightType}, + {arch}, + {EpilogueTag}, + {ThreadblockShape}, + {WarpShape}, + {Stages}>( + const {T}* A, + const {WeightType}* B, + const float* weight_scales, + const {T}* biases, + {T}* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy) { + generic_mixed_gemm_kernelLauncher<{T}, + {WeightType}, + {arch}, + {EpilogueTag}, + {ThreadblockShape}, + {WarpShape}, + {Stages}>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); +} +""" + +DefineHeader = """ +// Generated by generic_mixed_gemm_kernelLauncher.py - Do not edit. + +""" + +DefaultArch = [70, 75, 80] +epilogue_tags = ["bias", "biasFtGelu", "biasReLU", "noBias"] + +WeightTypes = ["uint8_t", "cutlass::uint4b_t"] +ThreadblockShapes = [ + "cutlass::gemm::GemmShape<32, 128, 64>", + "cutlass::gemm::GemmShape<64, 128, 64>", + "cutlass::gemm::GemmShape<128, 128, 64>", + "cutlass::gemm::GemmShape<256, 128, 64>", + "cutlass::gemm::GemmShape<128, 256, 64>", +] +WarpShapes = [ + "cutlass::gemm::GemmShape<32, 32, 64>", + "cutlass::gemm::GemmShape<64, 32, 64>", + "cutlass::gemm::GemmShape<128, 32, 64>", + "cutlass::gemm::GemmShape<64, 64, 64>", + "cutlass::gemm::GemmShape<64, 64, 64>", +] +StagesList = {70: [2], 75: [2], 80: [2, 3, 4]} + +ElementTypes = {"fp16": "half", "bf16": "__nv_bfloat16"} +Archs = { + 70: "cutlass::arch::Sm70", + 75: "cutlass::arch::Sm75", + 80: "cutlass::arch::Sm80", +} +EpilogueTags = { + "bias": "EpilogueOpBias", + "biasFtGelu": "EpilogueOpBiasFtGelu", + "biasReLU": "EpilogueOpBiasReLU", + "noBias": "EpilogueOpNoBias", +} + + +def SubstituteTemplate(template, values): + text = template + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext + return text + + +def find_arch_range(archs): + compile_archs = [] + for arch in archs: + if arch >= 70 and arch < 75: + compile_archs.append(70) + elif arch >= 75 and arch < 80: + compile_archs.append(75) + elif arch >= 80 and arch < 90: + compile_archs.append(80) + compile_archs = list(set(compile_archs)) + compile_archs.sort() + return compile_archs + + +def convert_to_arch_list(archs): + archs = archs.lower().strip() + if archs == "all": + return DefaultArch + + archs = [int(s.strip()) for s in archs.split(';') if s.strip()] + archs = list(set(archs)) + return find_arch_range(archs) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="The argument for generating the generic_mixed_gemm_kernelLauncher instance." 
+ ) + parser.add_argument( + "--cuda_arch", + type=convert_to_arch_list, + default=convert_to_arch_list("All"), + help="The CUDA architecture to be generated.", + ) + args = parser.parse_args() + return args + + +# generate source cu +def generate_source_cu( + element_type: str, arch: int, epilogue_tag: str, stages: int +): + all_code = CommonHead + for WeightType in WeightTypes: + for i in range(len(ThreadblockShapes)): + value_dict = { + "T": ElementTypes[element_type], + "WeightType": WeightType, + "arch": Archs[arch], + "EpilogueTag": EpilogueTags[epilogue_tag], + "ThreadblockShape": ThreadblockShapes[i], + "WarpShape": WarpShapes[i], + "Stages": str(stages), + } + all_code += SubstituteTemplate( + DispatchGemmConfigInstanceDeclare, value_dict + ) + all_code += CommonTail + return all_code + + +if __name__ == "__main__": + args = parse_args() + archs = args.cuda_arch + header_all = DefineHeader + header_name = "autogen/arch_define.h" + if archs: + for arch in archs: + define_line = "#define USE_FPAINTB_GEMM_WITH_SM%s\n" % str(arch) + header_all += define_line + with open(header_name, "w") as f: + f.write(header_all) + f.close() + if archs: + for element_type in ElementTypes.keys(): + for arch in archs: + for epilogue_tag in EpilogueTags.keys(): + for stages in StagesList[arch]: + file_name = "autogen/generic_mixed_gemm_kernelLauncher_{}_sm{}_stages{}_{}.cu".format( + element_type, arch, stages, epilogue_tag + ) + all_code = generate_source_cu( + element_type, arch, epilogue_tag, stages + ) + with open(file_name, "w") as f: + f.write(all_code) + f.close() diff --git a/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h b/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h index 69e737fa21157..e2e3652258406 100644 --- a/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h +++ b/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +42,9 @@ #include #endif +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/float16.h" + namespace phi { #define MAX_CONFIG_NUM 20 diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index a0695935de1bc..d592dfad0a52d 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -358,14 +358,14 @@ PD_REGISTER_KERNEL(relu_double_grad, PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cos_grad, CosGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tan_grad, TanGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(acos_grad, AcosGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(asin_grad, AsinGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(atan_grad, AtanGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(sinh_grad, SinhGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(cosh_grad, CoshGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(asinh_grad, AsinhGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(acosh_grad, AcoshGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(acos_grad, AcosGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(asin_grad, AsinGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(atan_grad, AtanGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sinh_grad, SinhGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cosh_grad, CoshGradKernel) 
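These registrations extend the inverse-trigonometric and hyperbolic activations (and exp/expm1 below) to complex dtypes. For reference, complex exp follows Euler's formula; a host-side sketch of the math, independent of the actual phi functors:

#include <cmath>
#include <complex>

// exp(a + bi) = e^a * (cos b + i sin b)
std::complex<double> ComplexExp(std::complex<double> z) {
  const double r = std::exp(z.real());
  return {r * std::cos(z.imag()), r * std::sin(z.imag())};
}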
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(asinh_grad, AsinhGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(acosh_grad, AcoshGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(atanh_grad, AtanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tanh_grad, TanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tanh_double_grad, TanhDoubleGradKernel) @@ -398,7 +398,9 @@ PD_REGISTER_KERNEL(exp_grad, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel) @@ -415,7 +417,9 @@ PD_REGISTER_KERNEL(expm1_grad, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(square_grad, GPU, diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index 061a02f531538..000428268bbb1 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -232,14 +232,14 @@ PD_REGISTER_KERNEL(relu, PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sin, SinKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cos, CosKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tan, TanKernel) -PD_REGISTER_ACTIVATION_KERNEL(acos, AcosKernel) -PD_REGISTER_ACTIVATION_KERNEL(asin, AsinKernel) -PD_REGISTER_ACTIVATION_KERNEL(atan, AtanKernel) -PD_REGISTER_ACTIVATION_KERNEL(sinh, SinhKernel) -PD_REGISTER_ACTIVATION_KERNEL(cosh, CoshKernel) -PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel) -PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel) -PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(acos, AcosKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(asin, AsinKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(atan, AtanKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sinh, SinhKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cosh, CoshKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(asinh, AsinhKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(acosh, AcoshKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(atanh, AtanhKernel) PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tanh, TanhKernel) PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel) @@ -261,7 +261,9 @@ PD_REGISTER_KERNEL(exp, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(expm1, GPU, ALL_LAYOUT, @@ -271,7 +273,9 @@ PD_REGISTER_KERNEL(expm1, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(square, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/bincount_kernel.cu b/paddle/phi/kernels/gpu/bincount_kernel.cu index 8ce5de402cc22..04f80970a40d3 100644 --- a/paddle/phi/kernels/gpu/bincount_kernel.cu +++ b/paddle/phi/kernels/gpu/bincount_kernel.cu @@ -34,13 +34,12 @@ __global__ void KernelBincount(const InputT* input, const bool has_weights, const T* weights, OutT* output) { - if (!has_weights) { - for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { - phi::CudaAtomicAdd(&output[input[i]], 1L); - } - } else { - for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { - 
phi::CudaAtomicAdd(&output[input[i]], static_cast(weights[i])); + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < total_elements) { + if (!has_weights) { + phi::CudaAtomicAdd(&output[input[tid]], 1L); + } else { + phi::CudaAtomicAdd(&output[input[tid]], static_cast(weights[tid])); } } } diff --git a/paddle/phi/kernels/gpu/flip_kernel.cu b/paddle/phi/kernels/gpu/flip_kernel.cu index 812d68df92d93..f271eba26e0ab 100644 --- a/paddle/phi/kernels/gpu/flip_kernel.cu +++ b/paddle/phi/kernels/gpu/flip_kernel.cu @@ -21,21 +21,26 @@ namespace phi { -template -__global__ void flip_cuda_kernel(const int64_t N, - const T* in_data, - T* out_data, - phi::Array shape, - phi::Array stride, - phi::Array flip_dims, - int flip_dims_size) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) { +template +__global__ void FlipCudaKernel(const T* in_data, + T* out_data, + phi::Array shape, + phi::Array stride, + phi::Array flip_dims, + const int rank, + const int64_t numel, + const int flip_dims_size) { + int64_t idx = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + if (idx >= numel) { return; } - int cur_indices = idx, rem = 0, dst_offset = 0; - for (int i = 0; i < Rank; ++i) { + int64_t cur_indices = idx; + int64_t rem = 0; + int64_t dst_offset = 0; + for (int i = 0; i < rank; ++i) { int64_t temp = cur_indices; cur_indices = cur_indices / stride[i]; rem = temp - cur_indices * stride[i]; @@ -51,91 +56,48 @@ __global__ void flip_cuda_kernel(const int64_t N, out_data[idx] = in_data[dst_offset]; } -template -void LaunchFlipCudaKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& axis, - DenseTensor* out) { +template +void FlipKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis, + DenseTensor* out) { auto* in_data = x.data(); auto* out_data = dev_ctx.template Alloc(out); auto x_dims = x.dims(); - const int total_dims = x_dims.size(); + const int rank = x_dims.size(); const int64_t numel = x.numel(); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); - auto x_stride = phi::stride(x_dims); - phi::Array stride_a; - phi::Array shape_a; - phi::Array flip_dims_a; size_t flip_dims_size = axis.size(); + auto x_stride = phi::stride(x_dims); - for (size_t idx = 0; idx < N; ++idx) { - stride_a[idx] = x_stride[idx]; - shape_a[idx] = x_dims[idx]; - flip_dims_a[idx] = idx < flip_dims_size ? axis[idx] : 0; - } + phi::Array stride_array; + phi::Array shape_array; + phi::Array flip_dims_array; - for (size_t i = 0; i < flip_dims_a.size(); ++i) { - if (flip_dims_a[i] < 0) { - flip_dims_a[i] += total_dims; + for (int i = 0; i < rank; ++i) { + stride_array[i] = x_stride[i]; + shape_array[i] = x_dims[i]; + if (i < flip_dims_size) { + flip_dims_array[i] = axis[i] < 0 ? 
axis[i] + rank : axis[i]; + } else { + flip_dims_array[i] = 0; } } - flip_cuda_kernel + + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + FlipCudaKernel <<>>( - numel, in_data, out_data, - shape_a, - stride_a, - flip_dims_a, + shape_array, + stride_array, + flip_dims_array, + rank, + numel, flip_dims_size); } -template -void FlipKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& axis, - DenseTensor* out) { - const size_t total_dims = x.dims().size(); - switch (total_dims) { - case 0: - LaunchFlipCudaKernel(dev_ctx, x, axis, out); - break; - case 1: - LaunchFlipCudaKernel(dev_ctx, x, axis, out); - break; - case 2: - LaunchFlipCudaKernel(dev_ctx, x, axis, out); - break; - case 3: - LaunchFlipCudaKernel(dev_ctx, x, axis, out); - break; - case 4: - LaunchFlipCudaKernel(dev_ctx, x, axis, out); - break; - case 5: - LaunchFlipCudaKernel(dev_ctx, x, axis, out); - break; - case 6: - LaunchFlipCudaKernel(dev_ctx, x, axis, out); - break; - case 7: - LaunchFlipCudaKernel(dev_ctx, x, axis, out); - break; - case 8: - LaunchFlipCudaKernel(dev_ctx, x, axis, out); - break; - case 9: - LaunchFlipCudaKernel(dev_ctx, x, axis, out); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "dims of input tensor should be less than 10, But received" - "%d", - x.dims().size())); - } -} } // namespace phi PD_REGISTER_KERNEL(flip, diff --git a/paddle/phi/kernels/gpu/fold_grad_kernel.cu b/paddle/phi/kernels/gpu/fold_grad_kernel.cu index ad469dd7981de..1e3cceb04dd0d 100644 --- a/paddle/phi/kernels/gpu/fold_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/fold_grad_kernel.cu @@ -18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/fold_grad_kernel_impl.h" -PD_REGISTER_KERNEL( - fold_grad, GPU, ALL_LAYOUT, phi::FoldGradKernel, float, double) {} +PD_REGISTER_KERNEL(fold_grad, + GPU, + ALL_LAYOUT, + phi::FoldGradKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/fold_kernel.cu b/paddle/phi/kernels/gpu/fold_kernel.cu index b53ef402150c2..2e21a121a0cc6 100644 --- a/paddle/phi/kernels/gpu/fold_kernel.cu +++ b/paddle/phi/kernels/gpu/fold_kernel.cu @@ -18,4 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/fold_kernel_impl.h" -PD_REGISTER_KERNEL(fold, GPU, ALL_LAYOUT, phi::FoldKernel, float, double) {} +PD_REGISTER_KERNEL(fold, + GPU, + ALL_LAYOUT, + phi::FoldKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index 8f2eba7185293..7e584e5c10318 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -24,20 +24,23 @@ namespace phi { -template -__global__ void set_zero_cuda_kernel(const int64_t N, - int64_t** indices, - phi::Array stride, - phi::Array shape, - T* out) { - int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; - int64_t cur_ix = 0; - - if (idx >= N) { +template +__global__ void SetZeroCudaKernel(int64_t** indices, + phi::Array stride, + phi::Array shape, + const int rank, + const int64_t numel, + T* out) { + int64_t idx = + static_cast(threadIdx.x) + + static_cast(blockDim.x) * static_cast(blockIdx.x); + if (idx >= numel) { return; } + + int64_t cur_ix = 0; int64_t offset = 0; - for (int i = 0; i < Rank; ++i) { + for (int i = 0; i < rank; ++i) { cur_ix = (static_cast(*(indices[i] + idx))); if (cur_ix < 0) { cur_ix += shape[i]; 
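A recurring refactor in this diff (flip, index_put, roll): the Rank template parameter and its per-rank switch are replaced by a fixed-capacity array (sized to phi's maximum rank, 9, matching the old "less than 10" limit) plus a runtime rank loop, so one kernel instantiation serves every rank at the cost of loop unrolling. A self-contained CUDA sketch of the flip-style variant, with illustrative raw-pointer parameters:

__global__ void FlipSketch(const float* in, float* out,
                           const int64_t* shape, const int64_t* stride,
                           const bool* flip, int rank, int64_t numel) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx >= numel) return;
  int64_t rem = idx, src = 0;
  for (int i = 0; i < rank; ++i) {  // runtime bound: one instantiation, any rank
    int64_t coord = rem / stride[i];
    rem -= coord * stride[i];
    if (flip[i]) coord = shape[i] - 1 - coord;  // mirror along flipped dims
    src += coord * stride[i];
  }
  out[idx] = in[src];
}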
@@ -48,21 +51,25 @@ __global__ void set_zero_cuda_kernel(const int64_t N, *(out + offset) = 0; } -template -__global__ void index_put_grad_cuda_kernel(const int64_t N, - const T* out_grad, - int64_t** indices, - phi::Array stride, - phi::Array shape, - T* value_grad) { - int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; - int64_t cur_ix = 0; - - if (idx >= N) { +template +__global__ void IndexPutGradCudaKernel( + const T* out_grad, + int64_t** indices, + phi::Array stride, + phi::Array shape, + const int rank, + const int64_t numel, + T* value_grad) { + int64_t idx = + static_cast(threadIdx.x) + + static_cast(blockDim.x) * static_cast(blockIdx.x); + if (idx >= numel) { return; } + + int64_t cur_ix = 0; int64_t offset = 0; - for (int i = 0; i < Rank; ++i) { + for (int i = 0; i < rank; ++i) { cur_ix = (static_cast(*(indices[i] + idx))); if (cur_ix < 0) { cur_ix += shape[i]; @@ -73,12 +80,13 @@ __global__ void index_put_grad_cuda_kernel(const int64_t N, *(value_grad + idx) = *(out_grad + offset); } -template +template void LaunchIndexPutGradCudaKernel( const Context& dev_ctx, const std::vector& indices, const DenseTensor& out_grad, - bool accumulate, + const int rank, + const bool accumulate, DenseTensor* value_grad, DenseTensor* x_grad) { if (x_grad) { @@ -87,43 +95,41 @@ void LaunchIndexPutGradCudaKernel( T* x_grad_data = x_grad->data(); auto x_grad_dims = x_grad->dims(); - const int64_t numel = indices[0]->numel(); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); auto x_grad_stride = phi::stride(x_grad_dims); - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = x_grad_stride[idx]; - shape_a[idx] = x_grad_dims[idx]; + phi::Array stride_array; + phi::Array shape_array; + for (int i = 0; i < rank; ++i) { + stride_array[i] = x_grad_stride[i]; + shape_array[i] = x_grad_dims[i]; } + const int64_t numel = indices[0]->numel(); auto pd_indices = funcs::GetDevicePointerArray(dev_ctx, indices); - set_zero_cuda_kernel<<>>( - numel, pd_indices, stride_a, shape_a, x_grad_data); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + SetZeroCudaKernel<<>>( + pd_indices, stride_array, shape_array, rank, numel, x_grad_data); } } auto out_grad_dims = out_grad.dims(); - const int64_t numel = indices[0]->numel(); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); auto out_grad_stride = phi::stride(out_grad_dims); - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = out_grad_stride[idx]; - shape_a[idx] = out_grad_dims[idx]; + phi::Array stride_array; + phi::Array shape_array; + for (int i = 0; i < rank; ++i) { + stride_array[i] = out_grad_stride[i]; + shape_array[i] = out_grad_dims[i]; } + const int64_t numel = indices[0]->numel(); auto pd_indices = funcs::GetDevicePointerArray(dev_ctx, indices); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); if (value_grad) { if (value_grad->numel() == 1) { @@ -133,16 +139,16 @@ void LaunchIndexPutGradCudaKernel( T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); auto out_grad_data = out_grad.data(); - index_put_grad_cuda_kernel - <<>>(numel, - out_grad_data, - pd_indices, - stride_a, - shape_a, - tmp_value_grad_data); + IndexPutGradCudaKernel<<>>(out_grad_data, + pd_indices, + stride_array, + shape_array, + rank, + numel, + tmp_value_grad_data); std::vector v_dims(tmp_value_grad.dims().size()); std::iota(v_dims.begin(), v_dims.end(), 0); @@ -157,11 
+163,16 @@ void LaunchIndexPutGradCudaKernel( T* value_grad_data = dev_ctx.template Alloc(value_grad); auto out_grad_data = out_grad.data(); - index_put_grad_cuda_kernel<<>>( - numel, out_grad_data, pd_indices, stride_a, shape_a, value_grad_data); + IndexPutGradCudaKernel<<>>(out_grad_data, + pd_indices, + stride_array, + shape_array, + rank, + numel, + value_grad_data); } else { DenseTensor tmp_value_grad(value_grad->dtype()); tmp_value_grad.Resize(indices[0]->dims()); @@ -169,16 +180,16 @@ void LaunchIndexPutGradCudaKernel( T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); auto out_grad_data = out_grad.data(); - index_put_grad_cuda_kernel - <<>>(numel, - out_grad_data, - pd_indices, - stride_a, - shape_a, - tmp_value_grad_data); + IndexPutGradCudaKernel<<>>(out_grad_data, + pd_indices, + stride_array, + shape_array, + rank, + numel, + tmp_value_grad_data); std::vector after_dims = phi::vectorize(tmp_value_grad.dims()); std::vector before_dims = phi::vectorize(value_grad->dims()); @@ -234,7 +245,6 @@ void IndexPutGradKernel(const Context& dev_ctx, return; } - const size_t total_dims = x.dims().size(); auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); std::vector res_dim_v(phi::vectorize(bd_dim)); @@ -256,37 +266,9 @@ void IndexPutGradKernel(const Context& dev_ctx, bd_dim, &res_dim_v); - switch (total_dims) { - case 1: - LaunchIndexPutGradCudaKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - case 2: - LaunchIndexPutGradCudaKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - case 3: - LaunchIndexPutGradCudaKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - case 4: - LaunchIndexPutGradCudaKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - case 5: - LaunchIndexPutGradCudaKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - case 6: - LaunchIndexPutGradCudaKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "dims of input tensor should be less than 7, But received" - "%d", - x.dims().size())); - } + const int rank = x.dims().size(); + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, rank, accumulate, value_grad, x_grad); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index 4244e755b6597..ccbd19aaba681 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -21,24 +21,27 @@ namespace phi { -template -__global__ void index_put_cuda_kernel(const int64_t N, - const T* x, - const T* vals, - int64_t** indices, - phi::Array stride, - phi::Array shape, - int64_t is_single_val_tensor, - bool accumulate, - T* out) { - int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; +template +__global__ void IndexPutCudaKernel(const T* x, + const T* vals, + int64_t** indices, + phi::Array stride, + phi::Array shape, + const int rank, + const int64_t numel, + const int64_t is_single_val_tensor, + const bool accumulate, + T* out) { + int64_t idx = + static_cast(threadIdx.x) + + static_cast(blockDim.x) * static_cast(blockIdx.x); int64_t cur_ix = 0; - if (idx >= N) { + if (idx >= numel) { return; } int64_t offset = 0; - for (int i = 0; i < Rank; ++i) { + for (int i = 0; i < rank; ++i) { cur_ix = (static_cast(*(indices[i] + idx))); if (cur_ix < 0) { cur_ix += shape[i]; @@ -53,7 +56,7 @@ 
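The index_put kernels in the following hunks share one addressing step: each coordinate gathered from the index tensors may be negative and is wrapped once by the dimension size before the flat offset is accumulated. A sketch of that step as a standalone device helper (illustrative signature, not the actual phi API):

__device__ int64_t FlatOffset(int64_t* const* indices,
                              const int64_t* shape,
                              const int64_t* stride,
                              int rank, int64_t idx) {
  int64_t offset = 0;
  for (int i = 0; i < rank; ++i) {
    int64_t cur = indices[i][idx];
    if (cur < 0) cur += shape[i];  // Python-style negative indexing
    offset += stride[i] * cur;
  }
  return offset;
}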
__global__ void index_put_cuda_kernel(const int64_t N, } } -template +template void LaunchIndexPutCudaKernel(const Context& dev_ctx, const DenseTensor& x, const std::vector& indices, @@ -62,38 +65,39 @@ void LaunchIndexPutCudaKernel(const Context& dev_ctx, DenseTensor* out) { auto* x_data = x.data(); auto* val_data = value.data(); + bool is_initialized = out->initialized(); T* out_data = dev_ctx.template Alloc(out); - if (!is_initialized) { phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); } auto x_dims = x.dims(); - const int64_t numel = indices[0]->numel(); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + const int rank = x_dims.size(); auto x_stride = phi::stride(x_dims); - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = x_stride[idx]; - shape_a[idx] = x_dims[idx]; + phi::Array stride_array; + phi::Array shape_array; + for (int i = 0; i < rank; ++i) { + stride_array[i] = x_stride[i]; + shape_array[i] = x_dims[i]; } int64_t is_single_val_tensor = (value.numel() == 1) ? 0 : INT64_MAX; - + const int64_t numel = indices[0]->numel(); auto pd_indices = funcs::GetDevicePointerArray(dev_ctx, indices); - index_put_cuda_kernel + + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + IndexPutCudaKernel <<>>( - numel, x_data, val_data, pd_indices, - stride_a, - shape_a, + stride_array, + shape_array, + rank, + numel, is_single_val_tensor, accumulate, out_data); @@ -124,7 +128,6 @@ void IndexPutKernel(const Context& dev_ctx, } return; } - const size_t total_dims = x.dims().size(); auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); std::vector res_dim_v(phi::vectorize(bd_dim)); @@ -158,37 +161,8 @@ void IndexPutKernel(const Context& dev_ctx, ptr_value = &value; } - switch (total_dims) { - case 1: - LaunchIndexPutCudaKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - case 2: - LaunchIndexPutCudaKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - case 3: - LaunchIndexPutCudaKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - case 4: - LaunchIndexPutCudaKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - case 5: - LaunchIndexPutCudaKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - case 6: - LaunchIndexPutCudaKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "dims of input tensor should be less than 7, But received" - "%d", - x.dims().size())); - } + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/matmul_kernel.cu b/paddle/phi/kernels/gpu/matmul_kernel.cu index 71095bf783b0b..5882cab4f4ee5 100644 --- a/paddle/phi/kernels/gpu/matmul_kernel.cu +++ b/paddle/phi/kernels/gpu/matmul_kernel.cu @@ -19,6 +19,7 @@ limitations under the License. 
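In the matmul registrations below, int8 inputs produce an int32 output via kernel->OutputAt(0).SetDataType(phi::DataType::INT32). The reason: a single int8 product already reaches 127 * 127 = 16129, far outside int8's [-128, 127] range, and a length-k dot product scales with k. A host-side illustration:

#include <cstdint>

// Accumulate int8 products in 32 bits to avoid overflow.
int32_t DotInt8(const int8_t* a, const int8_t* b, int k) {
  int32_t acc = 0;
  for (int i = 0; i < k; ++i) {
    acc += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
  }
  return acc;
}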
*/ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/matmul_kernel_impl.h" +#ifdef PADDLE_WITH_CUDA PD_REGISTER_KERNEL(matmul, GPU, ALL_LAYOUT, @@ -30,11 +31,46 @@ PD_REGISTER_KERNEL(matmul, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, - phi::dtype::complex) {} - -PD_REGISTER_KERNEL( - matmul_int8, GPU, ALL_LAYOUT, phi::MatmulInt8Kernel, int8_t) {} + phi::dtype::complex, + int8_t) { + if (kernel_key.dtype() == phi::DataType::INT8) { + kernel->OutputAt(0).SetDataType(phi::DataType::INT32); + } +} +#else +PD_REGISTER_KERNEL(matmul, + GPU, + ALL_LAYOUT, + phi::MatmulKernel, + float, + double, + int32_t, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) { + if (kernel_key.dtype() == phi::DataType::INT8) { + kernel->OutputAt(0).SetDataType(phi::DataType::INT32); + } +} +#endif +#ifdef PADDLE_WITH_CUDA +PD_REGISTER_KERNEL(matmul_with_flatten, + GPU, + ALL_LAYOUT, + phi::MatmulWithFlattenKernel, + int8_t, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) { + if (kernel_key.dtype() == phi::DataType::INT8) { + kernel->OutputAt(0).SetDataType(phi::DataType::INT32); + } +} +#else PD_REGISTER_KERNEL(matmul_with_flatten, GPU, ALL_LAYOUT, @@ -42,4 +78,9 @@ PD_REGISTER_KERNEL(matmul_with_flatten, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16) { + if (kernel_key.dtype() == phi::DataType::INT8) { + kernel->OutputAt(0).SetDataType(phi::DataType::INT32); + } +} +#endif diff --git a/paddle/phi/kernels/gpu/roll_grad_kernel.cu b/paddle/phi/kernels/gpu/roll_grad_kernel.cu index 3a523e58ca862..71d1cd356a269 100644 --- a/paddle/phi/kernels/gpu/roll_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/roll_grad_kernel.cu @@ -22,8 +22,6 @@ namespace phi { -using phi::PADDLE_CUDA_NUM_THREADS; - template void RollGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -31,23 +29,23 @@ void RollGradKernel(const Context& dev_ctx, const IntArray& shifts, const std::vector& axis, DenseTensor* x_grad) { - auto* in_data = out_grad.data(); - T* out_data = dev_ctx.template Alloc(x_grad); - int64_t numel = out_grad.numel(); - auto stream = dev_ctx.stream(); + auto* out_grad_data = out_grad.data(); + T* x_grad_data = dev_ctx.template Alloc(x_grad); auto shifts_data = shifts.GetData(); - size_t nums = shifts_data.size(); + int rank = shifts_data.size(); + + int64_t numel = out_grad.numel(); auto input_dim = out_grad.dims(); auto stride_dim = phi::stride(input_dim); - std::vector strides(nums), sizes(nums); + std::vector strides(rank), sizes(rank); if (axis.size() == 0) { strides[0] = 1; sizes[0] = numel; shifts_data[0] = ((-shifts_data[0]) % numel + numel) % numel; } else { - for (size_t i = 0; i < nums; i++) { + for (int i = 0; i < rank; i++) { int dim = axis[i] >= 0 ? 
axis[i] : axis[i] + input_dim.size(); int64_t size = input_dim[dim]; if (size != 0) { @@ -58,22 +56,14 @@ void RollGradKernel(const Context& dev_ctx, } } - switch (nums) { - CALL_ROLL_CUDA_KERNEL(1); - CALL_ROLL_CUDA_KERNEL(2); - CALL_ROLL_CUDA_KERNEL(3); - CALL_ROLL_CUDA_KERNEL(4); - CALL_ROLL_CUDA_KERNEL(5); - CALL_ROLL_CUDA_KERNEL(6); - CALL_ROLL_CUDA_KERNEL(7); - CALL_ROLL_CUDA_KERNEL(8); - CALL_ROLL_CUDA_KERNEL(9); - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "shifts.size() should be less than 10, But received shifts.size() " - "= %d", - shifts_data.size())); - } + LaunchRollKernel(dev_ctx, + out_grad_data, + x_grad_data, + rank, + numel, + shifts_data, + strides, + sizes); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/roll_kernel.cu b/paddle/phi/kernels/gpu/roll_kernel.cu index 0e87713df73aa..cf4f87ac11854 100644 --- a/paddle/phi/kernels/gpu/roll_kernel.cu +++ b/paddle/phi/kernels/gpu/roll_kernel.cu @@ -23,8 +23,6 @@ namespace phi { -using phi::PADDLE_CUDA_NUM_THREADS; - template void RollKernel(const Context& dev_ctx, const DenseTensor& x, @@ -33,22 +31,21 @@ void RollKernel(const Context& dev_ctx, DenseTensor* out) { auto* in_data = x.data(); T* out_data = dev_ctx.template Alloc(out); - int64_t numel = x.numel(); - auto stream = dev_ctx.stream(); auto shifts_data = shifts.GetData(); + int rank = shifts_data.size(); - size_t nums = shifts_data.size(); + int64_t numel = x.numel(); auto input_dim = x.dims(); auto stride_dim = phi::stride(input_dim); - std::vector strides(nums), sizes(nums); + std::vector strides(rank), sizes(rank); if (axis.size() == 0) { strides[0] = 1; sizes[0] = numel; shifts_data[0] = (shifts_data[0] % numel + numel) % numel; } else { - for (size_t i = 0; i < nums; i++) { + for (int i = 0; i < rank; i++) { int dim = axis[i] >= 0 ? 
axis[i] : axis[i] + input_dim.size(); int64_t size = input_dim[dim]; @@ -60,22 +57,8 @@ void RollKernel(const Context& dev_ctx, } } - switch (nums) { - CALL_ROLL_CUDA_KERNEL(1); - CALL_ROLL_CUDA_KERNEL(2); - CALL_ROLL_CUDA_KERNEL(3); - CALL_ROLL_CUDA_KERNEL(4); - CALL_ROLL_CUDA_KERNEL(5); - CALL_ROLL_CUDA_KERNEL(6); - CALL_ROLL_CUDA_KERNEL(7); - CALL_ROLL_CUDA_KERNEL(8); - CALL_ROLL_CUDA_KERNEL(9); - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "shifts.size() should be less than 10, But received shifts.size() " - "= %d", - shifts_data.size())); - } + LaunchRollKernel( + dev_ctx, in_data, out_data, rank, numel, shifts_data, strides, sizes); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/roll_kernel_impl.h b/paddle/phi/kernels/gpu/roll_kernel_impl.h index d3aa8798008a9..38e2a6ff669ad 100644 --- a/paddle/phi/kernels/gpu/roll_kernel_impl.h +++ b/paddle/phi/kernels/gpu/roll_kernel_impl.h @@ -22,23 +22,25 @@ namespace phi { using phi::PADDLE_CUDA_NUM_THREADS; -template +template __global__ void RollCudaKernel(const T* input, T* output, - int64_t N, - phi::Array shifts, - phi::Array strides, - phi::Array sizes) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) { + const int rank, + const int64_t numel, + phi::Array shifts, + phi::Array strides, + phi::Array sizes) { + int64_t idx = + static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + if (idx >= numel) { return; } int64_t output_idx = idx; int64_t new_dim_idx = 0; -#pragma unroll - for (size_t i = 0; i < Rank; i++) { + for (size_t i = 0; i < rank; i++) { new_dim_idx = (output_idx / strides[i]) % sizes[i] + shifts[i]; if (new_dim_idx >= sizes[i]) { output_idx += (shifts[i] - sizes[i]) * strides[i]; @@ -49,22 +51,33 @@ __global__ void RollCudaKernel(const T* input, output[output_idx] = input[idx]; } -#define CALL_ROLL_CUDA_KERNEL(N) \ - case N: { \ - phi::Array _strides; \ - phi::Array _shifts; \ - phi::Array _sizes; \ - for (size_t idx = 0; idx < N; ++idx) { \ - _strides[idx] = strides[idx]; \ - _shifts[idx] = shifts_data[idx]; \ - _sizes[idx] = sizes[idx]; \ - } \ - RollCudaKernel \ - <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, \ - PADDLE_CUDA_NUM_THREADS, \ - 0, \ - stream>>>(in_data, out_data, numel, _shifts, _strides, _sizes); \ - break; \ +template +void LaunchRollKernel(const Context& dev_ctx, + const T* input, + T* output, + const int rank, + const int64_t numel, + const std::vector shifts, + const std::vector strides, + const std::vector sizes) { + using phi::PADDLE_CUDA_NUM_THREADS; + + phi::Array strides_array; + phi::Array shifts_array; + phi::Array sizes_array; + for (int i = 0; i < rank; ++i) { + strides_array[i] = strides[i]; + shifts_array[i] = shifts[i]; + sizes_array[i] = sizes[i]; } + auto stream = dev_ctx.stream(); + RollCudaKernel + <<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>( + input, output, rank, numel, shifts_array, strides_array, sizes_array); +} + } // namespace phi diff --git a/paddle/phi/kernels/impl/compare_kernel_impl.h b/paddle/phi/kernels/impl/compare_kernel_impl.h index 92e10afc50a42..907bd5a20a104 100644 --- a/paddle/phi/kernels/impl/compare_kernel_impl.h +++ b/paddle/phi/kernels/impl/compare_kernel_impl.h @@ -30,20 +30,35 @@ inline void CompareKernelImpl(const Context& ctx, int axis, DenseTensor* out); +template +inline void InplaceCompareKernelImpl(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + 
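The InplaceCompareKernelImpl declaration above serves the in-place variants, where out shares its allocation with x and the kernel re-types that shared buffer to BOOL; the implementations later in this diff therefore keep a holder copy (auto x_origin = x;) alive before writing. A standalone model of that aliasing hazard, sketched with plain standard-library types rather than phi tensors:

#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

int main() {
  // buf stands in for the allocation shared by x and out.
  auto buf = std::make_shared<std::vector<int>>(std::vector<int>{1, 5, 3});
  auto x_origin = buf;  // mirrors "auto x_origin = x;" in the in-place impls
  std::vector<bool> result(x_origin->size());
  for (std::size_t i = 0; i < x_origin->size(); ++i) {
    result[i] = (*x_origin)[i] < 4;  // e.g. a less_than comparison
  }
  buf.reset();                  // only now may out repurpose the allocation
  assert(x_origin != nullptr);  // input payload stayed alive during compute
  return 0;
}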
template inline void CompareAllKernelImpl(const Context& ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* out); -#define DEFINE_COMPARE_KERNEL(name, functor, inverse_functor) \ - template \ - void name##Kernel(const Context& ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - DenseTensor* out) { \ - CompareKernelImpl, inverse_functor>( \ - ctx, x, y, -1, out); \ +#define DEFINE_COMPARE_KERNEL(name, functor, inverse_functor) \ + template \ + void name##Kernel(const Context& ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + if (out->IsSharedWith(x)) { \ + InplaceCompareKernelImpl, inverse_functor>( \ + ctx, x, y, -1, out); \ + } else { \ + CompareKernelImpl, inverse_functor>( \ + ctx, x, y, -1, out); \ + } \ } DEFINE_COMPARE_KERNEL(LessThan, diff --git a/paddle/phi/kernels/impl/fill_kernel_impl.h b/paddle/phi/kernels/impl/fill_kernel_impl.h index 6894204cd06a4..4e8cda48f6dd6 100644 --- a/paddle/phi/kernels/impl/fill_kernel_impl.h +++ b/paddle/phi/kernels/impl/fill_kernel_impl.h @@ -27,9 +27,9 @@ void FillKernel(const Context& dev_ctx, const DenseTensor& x UNUSED, const Scalar& value, DenseTensor* out) { - T fill_var = value.to(); + double fill_var = value.to(); - PADDLE_ENFORCE_EQ(std::isnan(static_cast(fill_var)), + PADDLE_ENFORCE_EQ(std::isnan(fill_var), false, phi::errors::InvalidArgument("fill value should not be NaN," " but received NaN")); @@ -37,7 +37,7 @@ void FillKernel(const Context& dev_ctx, dev_ctx.template Alloc(out); phi::funcs::SetConstant functor; - functor(dev_ctx, out, fill_var); + functor(dev_ctx, out, value.to()); } } // namespace phi diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h index e680e164e623d..b3b9d82d19eec 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -954,14 +954,6 @@ struct MatMulDispatcher { } }; -static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx, - size_t workspace_size) { - return phi::memory_utils::Alloc( - ctx.GetPlace(), - workspace_size, - phi::Stream(reinterpret_cast(ctx.stream()))); -} - #endif // PADDLE_WITH_CUDA template @@ -979,7 +971,7 @@ void MatMulFunction(const Context& ctx, } template -void MatMulInt8Function(const Context& ctx, +bool MatMulInt8Function(const Context& ctx, const DenseTensor& x, const DenseTensor& y, const std::vector& x_dims, @@ -987,49 +979,245 @@ void MatMulInt8Function(const Context& ctx, DenseTensor* out, bool trans_x, bool trans_y) { - PADDLE_ENFORCE_EQ( - x.dtype(), - DataType::INT8, - phi::errors::InvalidArgument( - "The type of input(x) used in int8 matmul must be (%s) does not " - "match the " - "type of data (%s) currently contained in the container.", - phi::CppTypeToDataType::Type(), - x.dtype())); - PADDLE_ENFORCE_EQ( - y.dtype(), - DataType::INT8, - phi::errors::InvalidArgument( - "The type of input(y) used in int8 matmul must be (%s) does not " - "match the " - "type of data (%s) currently contained in the container.", - phi::CppTypeToDataType::Type(), - x.dtype())); -#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020 + return false; +} + +#ifdef PADDLE_WITH_CUDA +template <> +bool inline MatMulInt8Function(const phi::GPUContext& ctx, + const DenseTensor& x, + const DenseTensor& y, + const std::vector& x_dims, + const std::vector& y_dims, + DenseTensor* out, + bool trans_x, + bool trans_y) { + if (x.dtype() != DataType::INT8 || y.dtype() != DataType::INT8) { + return false; + } +#if CUDA_VERSION >= 11060 
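The CUDA_VERSION >= 11060 guard above restricts this cublasLt path to CUDA 11.6 and newer. Independently of that, both matmul registrations earlier in this diff force an INT32 output for INT8 inputs, which is plain integer-range arithmetic: a single K-length dot product of int8 values overflows int8 almost immediately. A minimal standalone illustration (K here is a hypothetical reduction length):

#include <cstdint>
#include <cstdio>

int main() {
  const int K = 64;  // hypothetical reduction length
  std::int32_t acc = 0;
  for (int k = 0; k < K; ++k) {
    acc += std::int32_t(127) * std::int32_t(-128);  // worst-case int8 product
  }
  std::printf("acc = %d\n", acc);  // -1040384, far outside [-128, 127]
  return 0;
}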
const int x_ndim = x_dims.size(); const int y_ndim = y_dims.size(); - PADDLE_ENFORCE_EQ( - x_ndim, - 2, - phi::errors::InvalidArgument("[INT8 GEMM] The number of dims of input(x) " - "must be equal to 2 but received %d", - x_ndim)); - PADDLE_ENFORCE_EQ( - y_ndim, - 2, - phi::errors::InvalidArgument("[INT8 GEMM] The number of dims of input(x) " - "must be equal to 2 but received %d", - y_ndim)); - PADDLE_ENFORCE_EQ( + const int8_t* x_data = x.data(); + const int8_t* y_data = y.data(); + using blaslt = phi::funcs::MatmulWithCublasLt; + + phi::funcs::MatmulPlanner matmul_planner( + x_dims, + y_dims, trans_x, - false, - phi::errors::InvalidArgument("[INT8 GEMM] Input(x) must be not " - "transposed to acheive better performance")); - PADDLE_ENFORCE_EQ( trans_y, - true, - phi::errors::InvalidArgument("[INT8 GEMM] Input(y) must be transposed to " - "acheive better performance")); + phi::CppTypeToDataType::Type(), + funcs::MatmulFusedType::kMatmul, + /* bias_data */ nullptr, + /* reserve_data */ nullptr, + /* use_addto */ false, + /* no_exchange */ true); + + if (x_ndim == 1 && y_ndim == 1) { + const int M = x.numel(); + const int N = y.numel(); + PADDLE_ENFORCE_EQ( + M, + N, + phi::errors::InvalidArgument( + "X's numel must be equal to Y's numel " + "when X and Y are both 1-D. But received X has [%d] elements " + "and Y has [%d] elements", + M, + N)); + if (!(M % 4 == 0)) { + return false; + } + + out->Resize(phi::make_ddim({})); + ctx.template Alloc(out); + blaslt::Run(ctx, + y_data, + x_data, + ctx.template Alloc(out), + 1, + 1, + M, + false, + true, + &matmul_planner); + return true; + } + if (x_ndim == 1) { + const int N = x.numel(); + if (trans_y) { + PADDLE_ENFORCE_EQ( + y_dims[y_ndim - 1], + N, + phi::errors::InvalidArgument("Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 1, + N, + y_ndim - 1, + y_dims[y_ndim - 1])); + if (!(N % 4 == 0)) { + return false; + } + } else { + PADDLE_ENFORCE_EQ( + y_dims[y_ndim - 2], + N, + phi::errors::InvalidArgument("Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 2, + N, + y_ndim - 2, + y_dims[y_ndim - 2])); + const int M = y.numel() / N; + if (!(M == 1 || M % 4 == 0)) { + return false; + } + } + std::vector out_dims(y_ndim - 1); + if (trans_y) { + std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); + } else { + std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); + out_dims.back() = y_dims.back(); + } + out->ResizeAndAllocate(phi::make_ddim(out_dims)); + ctx.template Alloc(out); + if (trans_y) { + const int M = y.numel() / N; + blaslt::Run(ctx, + y_data, + x_data, + ctx.template Alloc(out), + M, + 1, + N, + false, + false, + &matmul_planner); + } else { + const int M = y_dims[y_ndim - 1]; + const int batch_size = y.numel() / (M * N); + if (batch_size == 1) { + blaslt::Run(ctx, + y_data, + x_data, + ctx.template Alloc(out), + M, + 1, + N, + true, + false, + &matmul_planner); + } else { + blaslt::RunWithBatch(ctx, + y_data, + x_data, + ctx.template Alloc(out), + M, + 1, + N, + true, + false, + batch_size, + M * N, + 0, + M, + &matmul_planner); + } + } + return true; + } + + if (y_ndim == 1) { + const int N = y.numel(); + if (trans_x) { + PADDLE_ENFORCE_EQ( + x_dims[x_ndim - 2], + N, + phi::errors::InvalidArgument("Input(X) has error dim."
+ "X'dims[%d] must be equal to %d" + "But received X'dims[%d] is %d", + x_ndim - 2, + N, + x_ndim - 2, + x_dims[x_ndim - 2])); + const int M = x.numel() / N; + if (!((M == 1 || M % 4 == 0))) { + return false; + } + } else { + PADDLE_ENFORCE_EQ( + x_dims[x_ndim - 1], + N, + phi::errors::InvalidArgument("Input(X) has error dim." + "X'dims[%d] must be equal to %d" + "But received X'dims[%d] is %d", + x_ndim - 1, + N, + x_ndim - 1, + x_dims[x_ndim - 1])); + if (N % 4 != 0) { + return false; + } + } + std::vector out_dims(x_ndim - 1); + if (trans_x) { + std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); + out_dims.back() = x_dims.back(); + } else { + std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); + } + out->ResizeAndAllocate(phi::make_ddim(out_dims)); + ctx.template Alloc(out); + + if (trans_x) { + const int M = x_dims[x_ndim - 1]; + const int batch_size = x.numel() / (M * N); + if (batch_size == 1) { + blaslt::Run(ctx, + x_data, + y_data, + ctx.template Alloc(out), + M, + 1, + N, + true, + false, + &matmul_planner); + } else { + blaslt::RunWithBatch(ctx, + x_data, + y_data, + ctx.template Alloc(out), + M, + 1, + N, + true, + false, + batch_size, + M * N, + 0, + M, + &matmul_planner); + } + } else { + const int M = x.numel() / N; + blaslt::Run(ctx, + x_data, + y_data, + ctx.template Alloc(out), + M, + 1, + N, + false, + false, + &matmul_planner); + } + return true; + } const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; @@ -1057,27 +1245,186 @@ void MatMulInt8Function(const Context& ctx, y_dims[y_ndim - 2])); } const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; + const int ndim = (std::max)(x_ndim, y_ndim); + std::vector x_broadcast_dims(ndim); + std::vector y_broadcast_dims(ndim); + std::vector out_broadcast_dims(ndim); + GetBroadcastFromDims(x_ndim - 2, + x_dims.data(), + y_ndim - 2, + y_dims.data(), + x_broadcast_dims.data(), + y_broadcast_dims.data(), + out_broadcast_dims.data()); + out_broadcast_dims[ndim - 2] = M; + out_broadcast_dims[ndim - 1] = N; - size_t workspace_size = static_cast(4) * 1024 * 1024; - phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size); + out->ResizeAndAllocate(phi::make_ddim(out_broadcast_dims)); + ctx.template Alloc(out); - // TODO(wufeisheng): cublaslt_helper is a temp scheme for Int8 GEMM, - // and releted functions need to be integrated into - // phi::funcs::MatmulWithCublasLt - auto cublaslt_helper = CublasLtHelper(M, K, N, ctx.cublaslt_handle()); + const int batch_dim = ndim - 2; + // broadcast message + const bool is_broadcast_dims = + !std::equal(x_broadcast_dims.cbegin(), + x_broadcast_dims.cbegin() + batch_dim, + y_broadcast_dims.cbegin()); - ctx.template Alloc(out); - cublaslt_helper.GEMM(x.data(), - y.data(), - out->data(), - ctx.stream(), - workspace->ptr()); + const std::int64_t x_batch_size = + std::accumulate(x_broadcast_dims.cbegin(), + x_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies()); + const std::int64_t y_batch_size = + std::accumulate(y_broadcast_dims.cbegin(), + y_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies()); + const std::int64_t out_batch_size = + std::accumulate(out_broadcast_dims.cbegin(), + out_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies()); + if (out_batch_size == 0) return true; + if (x_batch_size == 1 && M == 1 && trans_y) { + if (!(K % 4 == 0)) { + return false; + } + } else if (!trans_x && !trans_y) { + if (!(N % 4 == 0 || N == 1) || !(K % 4 == 
0) || (M == 1 && N == 1)) { + return false; + } + } else if (!trans_x && trans_y) { + if (!(K % 4 == 0)) { + return false; + } + } else if (trans_x && !trans_y) { + if (!(M % 4 == 0 || M == 1) || !(N % 4 == 0 || N == 1)) { + return false; + } + } else { + if (!(M % 4 == 0 || M == 1) || !(K % 4 == 0)) { + return false; + } + } + if (x_batch_size == 1 && y_batch_size == 1) { + blaslt::Run(ctx, + x_data, + y_data, + ctx.template Alloc(out), + M, + N, + K, + trans_x, + trans_y, + &matmul_planner); + } else if (x_batch_size == 1) { + if (M == 1 && trans_y) { + blaslt::Run(ctx, + y_data, + x_data, + ctx.template Alloc(out), + y_batch_size * N, + 1, + K, + false, + false, + &matmul_planner); + } else { + blaslt::RunWithBatch(ctx, + x_data, + y_data, + ctx.template Alloc(out), + M, + N, + K, + trans_x, + trans_y, + out_batch_size, + 0, + K * N, + M * N, + &matmul_planner); + } + } else if (y_batch_size == 1) { + if (!trans_x) { + blaslt::Run(ctx, + x_data, + y_data, + ctx.template Alloc(out), + x_batch_size * M, + N, + K, + false, + trans_y, + &matmul_planner); + } else { + blaslt::RunWithBatch(ctx, + x_data, + y_data, + ctx.template Alloc(out), + M, + N, + K, + true, + trans_y, + out_batch_size, + M * K, + 0, + M * N, + &matmul_planner); + } + } else if (!is_broadcast_dims) { + blaslt::RunWithBatch(ctx, + x_data, + y_data, + ctx.template Alloc(out), + M, + N, + K, + trans_x, + trans_y, + out_batch_size, + M * K, + K * N, + M * N, + &matmul_planner); + } else { + // in this case, strided GEMM can't be used + std::vector x_ptr(out_batch_size); + std::vector y_ptr(out_batch_size); + std::vector out_ptr(out_batch_size); + std::vector index(batch_dim, 0); + for (std::int64_t i = 0; i < out_batch_size; ++i) { + // use the index to compute each batch's offset + const std::int64_t x_index = + GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); + const std::int64_t y_index = + GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); + + x_ptr[i] = x_data + x_index * M * K; + y_ptr[i] = y_data + y_index * K * N; + out_ptr[i] = ctx.template Alloc(out) + i * M * N; + IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); + } + blaslt::RunWithBatch(ctx, + x_ptr.data(), + y_ptr.data(), + out_ptr.data(), + M, + N, + K, + trans_x, + trans_y, + out_batch_size, + &matmul_planner); + } + return true; #else - PADDLE_THROW(phi::errors::Unimplemented( - "MatmulInt8 op needs paddle with cuda and cuda version >= 11.2")); + return false; #endif } +#endif template typename std::enable_if::value>::type @@ -1089,6 +1436,11 @@ MatmulJudgeDtypeKernel(const Context& ctx, DenseTensor* out, bool transpose_x, bool transpose_y) { + bool try_matmul_int8 = MatMulInt8Function( + ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); + if (try_matmul_int8) { + return; + } auto x_tmp = phi::Cast(ctx, x, phi::DataType::FLOAT32); auto y_tmp = phi::Cast(ctx, y, phi::DataType::FLOAT32); DenseTensor out_tmp; @@ -1135,35 +1487,12 @@ void MatmulKernel(const Context& ctx, } template -void MatmulInt8Kernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - bool transpose_x, - bool transpose_y, - DenseTensor* out) { - PADDLE_ENFORCE_NE( - phi::product(x.dims()), - 0, - phi::errors::InvalidArgument("The Input(X) dims size must not be equal 0," - " but reviced dims size is 0. ")); - PADDLE_ENFORCE_NE( - phi::product(y.dims()), - 0, - phi::errors::InvalidArgument("The Input(Y) dims size must not be equal 0," - " but reviced dims size is 0. 
")); - const std::vector x_dims = vectorize(x.dims()); - const std::vector y_dims = vectorize(y.dims()); - MatMulInt8Function( - ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); -} - -template -void MatmulWithFlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int x_num_col_dims, - int y_num_col_dims, - DenseTensor* out) { +void MatmulWithFlattenKernelImpl(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* out) { const DenseTensor x_matrix = x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; const DenseTensor y_matrix = @@ -1183,4 +1512,170 @@ void MatmulWithFlattenKernel(const Context& dev_ctx, } } +#ifdef PADDLE_WITH_CUDA + +template +void MatmulWithFlattenKernelInt8Impl(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + x.dtype(), + DataType::INT8, + phi::errors::InvalidArgument( + "The type of input(x) used in int8 mul must be (%s), " + "which does not match the " + "type of data (%s) currently contained in the container.", + phi::CppTypeToDataType::Type(), + x.dtype())); + PADDLE_ENFORCE_EQ( + y.dtype(), + DataType::INT8, + phi::errors::InvalidArgument( + "The type of input(y) used in int8 mul must be (%s), " + "which does not match the " + "type of data (%s) currently contained in the container.", + phi::CppTypeToDataType::Type(), + y.dtype())); + + const DenseTensor x_matrix = + x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; + const DenseTensor y_matrix = + y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; + + PADDLE_ENFORCE_EQ( + x_matrix.dims()[1], + y_matrix.dims()[0], + phi::errors::InvalidArgument( + "X's number of columns must be equal to Y's number of rows."
+ "But received X has [%d] columns " + "and Y has [%d] rows", + x_matrix.dims()[1], + y_matrix.dims()[0])); + + PADDLE_ENFORCE_EQ((y_matrix.dims()[1] % 4 == 0 || y_matrix.dims()[1] == 1), + true, + phi::errors::InvalidArgument( + "The dimension size N used in int8 mul must be 1 " + "or a multiple of 4, but received size (%d).", + y_matrix.dims()[1])); + PADDLE_ENFORCE_EQ((x_matrix.dims()[1] % 4 == 0), + true, + phi::errors::InvalidArgument( + "The dimension size K used in int8 mul must be a " + "multiple of 4, but received size (%d).", + x_matrix.dims()[1])); + + dev_ctx.template Alloc(out); + auto z_dim = out->dims(); + if (z_dim.size() != 2) { + out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); + } + +#if CUDA_VERSION >= 11060 + using blaslt = phi::funcs::MatmulWithCublasLt; + + const int8_t* x_data = x_matrix.data(); + const int8_t* y_data = y_matrix.data(); + + std::vector x_dims = {x_matrix.dims()[0], x_matrix.dims()[1]}; + std::vector y_dims = {y_matrix.dims()[0], y_matrix.dims()[1]}; + phi::funcs::MatmulPlanner matmul_planner( + x_dims, + y_dims, + false, + false, + phi::CppTypeToDataType::Type(), + funcs::MatmulFusedType::kMatmul, + /* bias_data */ nullptr, + /* reserve_data */ nullptr, + /* use_addto */ false, + /* no_exchange */ true); + + blaslt::Run(dev_ctx, + x_data, + y_data, + dev_ctx.template Alloc(out), + x_matrix.dims()[0], + y_matrix.dims()[1], + x_matrix.dims()[1], + false, + false, + &matmul_planner); + + if (z_dim.size() != 2) { + out->Resize(z_dim); + } +#endif +} +#endif + +#ifdef PADDLE_WITH_CUDA +template +typename std::enable_if::value, + void>::type +DispatchMatmulWithFlattenInt8Kernel(const phi::GPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* out) { + MatmulWithFlattenKernelInt8Impl( + dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); +} +#endif + +template +typename std::enable_if::value, + void>::type +DispatchMatmulWithFlattenInt8Kernel(const phi::CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* out) { + PADDLE_THROW(phi::errors::Unimplemented( + "MatmulWithFlatten with CPU is NOT implemented " + "yet.")); +} + +template +typename std::enable_if::value, void>::type +DispatchMatmulFlattenKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* out) { + DispatchMatmulWithFlattenInt8Kernel( + dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); +} + +template +typename std::enable_if::value, void>::type +DispatchMatmulFlattenKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* out) { + MatmulWithFlattenKernelImpl( + dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); +} + +template +void MatmulWithFlattenKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int x_num_col_dims, + int y_num_col_dims, + DenseTensor* out) { + DispatchMatmulFlattenKernel( + dev_ctx, x, y, x_num_col_dims, y_num_col_dims, out); +} + } // namespace phi diff --git a/paddle/phi/kernels/kps/compare_kernel.cu b/paddle/phi/kernels/kps/compare_kernel.cu index 728298be2cd83..545a9df2961bf 100644 --- a/paddle/phi/kernels/kps/compare_kernel.cu +++ b/paddle/phi/kernels/kps/compare_kernel.cu @@ -52,16 +52,27 @@ inline void CompareKernelImpl(const 
Context& ctx, const DenseTensor& y, int axis, DenseTensor* out) { - if (!out->IsSharedWith(x)) { - ctx.template Alloc(out); - } + ctx.template Alloc(out); std::vector ins{&x, &y}; std::vector outs{out}; - if (!out->IsSharedWith(x)) { - funcs::BroadcastKernel(ctx, ins, &outs, Functor(), axis); - } else { - funcs::BroadcastKernel(ctx, ins, &outs, Functor(), axis); - } + funcs::BroadcastKernel(ctx, ins, &outs, Functor(), axis); +} + +template +inline void InplaceCompareKernelImpl(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + auto x_origin = x; + ctx.template Alloc(out); + out->set_type(phi::DataType::BOOL); + std::vector ins{&x_origin, &y}; + std::vector outs{out}; + funcs::BroadcastKernel(ctx, ins, &outs, Functor(), axis); } #ifndef PADDLE_WITH_XPU_KP @@ -134,18 +145,21 @@ PD_REGISTER_KERNEL(equal_all, kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); } -#define PD_REGISTER_COMPARE_KERNEL(name, func) \ - PD_REGISTER_KERNEL(name, \ - KPS, \ - ALL_LAYOUT, \ - phi::func##Kernel, \ - bool, \ - int16_t, \ - int, \ - int64_t, \ - float, \ - double, \ - phi::dtype::float16) {} +#define PD_REGISTER_COMPARE_KERNEL(name, func) \ + PD_REGISTER_KERNEL(name, \ + KPS, \ + ALL_LAYOUT, \ + phi::func##Kernel, \ + bool, \ + int16_t, \ + int, \ + int64_t, \ + float, \ + double, \ + phi::dtype::float16, \ + phi::dtype::bfloat16) { \ + kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ + } PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual) diff --git a/paddle/phi/kernels/kps/elementwise_add_kernel.cu b/paddle/phi/kernels/kps/elementwise_add_kernel.cu deleted file mode 100644 index b3fe46a1cd310..0000000000000 --- a/paddle/phi/kernels/kps/elementwise_add_kernel.cu +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
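The add kernels deleted in the file below reappear in elementwise_kernel.cu further down in this diff, including the fp32 + bf16/fp16 mixed-precision path. That path is cheap because bfloat16 is the top 16 bits of an IEEE-754 float32, so widening is a single shift; a standalone sketch of the widening (bit manipulation only, not the phi functor):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Widen a bfloat16 bit pattern to float: bf16 keeps the sign, the full
// 8-bit exponent, and the top 7 mantissa bits of a float32.
static float bf16_to_float(std::uint16_t b) {
  std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  const std::uint16_t y = 0x3FC0;  // bf16 encoding of 1.5f
  const float x = 2.25f;
  std::printf("%g\n", x + bf16_to_float(y));  // 3.75, accumulated in fp32
  return 0;
}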
- -#include "paddle/phi/backends/gpu/gpu_context.h" -#ifndef PADDLE_WITH_XPU_KP -#include "paddle/phi/common/complex.h" -#include "paddle/phi/common/float16.h" -#endif -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/elementwise_add_kernel.h" -#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h" - -namespace phi { - -template -void AddCudaFunctor(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - int axis, - DenseTensor* out) { - std::vector inputs; - inputs.reserve(2); - std::vector outputs; - outputs.reserve(1); - inputs.emplace_back(&x); - inputs.emplace_back(&y); - outputs.emplace_back(out); - dev_ctx.template Alloc(out); - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, funcs::AddFunctor(), axis); -} - -template -void Float32Bfloat16OrFloat16AddCudaFunctor(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { - std::vector inputs; - inputs.reserve(2); - std::vector outputs; - outputs.reserve(1); - inputs.emplace_back(&x); - inputs.emplace_back(&y); - outputs.emplace_back(out); - if (y.dtype() == phi::DataType::BFLOAT16) { - funcs::ElementwiseKernel( - dev_ctx, inputs, &outputs, funcs::Float32Bfloat16AddFunctor()); - } else if (y.dtype() == phi::DataType::FLOAT16) { - funcs::ElementwiseKernel( - dev_ctx, inputs, &outputs, funcs::Float32Float16AddFunctor()); - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "Unsupport x dtype:%s, y dtype:%s for add(x, y) operation", - phi::DataTypeToString(x.type()), - phi::DataTypeToString(y.type()))); - } -} - -template -void AddKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { -#ifdef PADDLE_WITH_CUDA - if (x.dtype() == phi::DataType::FLOAT32 && - (y.dtype() == phi::DataType::BFLOAT16 || - y.dtype() == phi::DataType::FLOAT16)) { - using Type = DataTypeToCppType::type; - Float32Bfloat16OrFloat16AddCudaFunctor(dev_ctx, x, y, out); - } else { -#endif - AddCudaFunctor(dev_ctx, x, y, -1, out); -#ifdef PADDLE_WITH_CUDA - } -#endif -} - -template -void GradAddKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { - AddCudaFunctor(dev_ctx, x, y, -1, out); -} - -} // namespace phi - -#ifdef PADDLE_WITH_XPU_KP -PD_REGISTER_KERNEL(add, KPS, ALL_LAYOUT, phi::AddKernel, float) {} -#else - -using float16 = phi::dtype::float16; -using bfloat16 = phi::dtype::bfloat16; -using complex64 = ::phi::dtype::complex; -using complex128 = ::phi::dtype::complex; - -PD_REGISTER_KERNEL(add, - KPS, - ALL_LAYOUT, - phi::AddKernel, - float, - double, - int16_t, - int, - int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - complex64, - complex128) {} - -PD_REGISTER_KERNEL(grad_add, - KPS, - ALL_LAYOUT, - phi::GradAddKernel, - float, - double, - int16_t, - int, - int64_t, - phi::dtype::float16, - phi::dtype::bfloat16, - complex64, - complex128) {} -#endif diff --git a/paddle/phi/kernels/kps/elementwise_divide_kernel.cu b/paddle/phi/kernels/kps/elementwise_divide_kernel.cu deleted file mode 100644 index 1f4e4ad05adde..0000000000000 --- a/paddle/phi/kernels/kps/elementwise_divide_kernel.cu +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/backends/gpu/gpu_context.h" -#ifndef PADDLE_WITH_XPU_KP -#include "paddle/phi/common/complex.h" -#include "paddle/phi/common/float16.h" -#endif -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h" - -namespace phi { - -template -void DivideKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { - std::vector inputs; - inputs.reserve(2); - std::vector outputs; - outputs.reserve(1); - inputs.emplace_back(&x); - inputs.emplace_back(&y); - outputs.emplace_back(out); - dev_ctx.template Alloc(out); - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, funcs::DivideFunctor(), -1); -} - -} // namespace phi - -#ifdef PADDLE_WITH_XPU_KP -PD_REGISTER_KERNEL(divide, KPS, ALL_LAYOUT, phi::DivideKernel, float) {} -#else - -using float16 = phi::dtype::float16; -using bfloat16 = phi::dtype::bfloat16; -using complex64 = ::phi::dtype::complex; -using complex128 = ::phi::dtype::complex; - -PD_REGISTER_KERNEL(divide, - KPS, - ALL_LAYOUT, - phi::DivideKernel, - float, - double, - int, - int64_t, - float16, - bfloat16, - complex64, - complex128) {} - -#endif diff --git a/paddle/phi/kernels/kps/elementwise_kernel.cu b/paddle/phi/kernels/kps/elementwise_kernel.cu index e88714c370be9..e4eacda7d98fb 100644 --- a/paddle/phi/kernels/kps/elementwise_kernel.cu +++ b/paddle/phi/kernels/kps/elementwise_kernel.cu @@ -18,11 +18,129 @@ #include "paddle/phi/common/float16.h" #endif #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/elementwise_add_kernel.h" #include "paddle/phi/kernels/impl/elementwise_kernel_impl.h" #include "paddle/phi/kernels/legacy/elementwise_kernel.h" namespace phi { +template +void SubtractKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + std::vector inputs; + inputs.reserve(2); + std::vector outputs; + outputs.reserve(1); + inputs.emplace_back(&x); + inputs.emplace_back(&y); + outputs.emplace_back(out); + dev_ctx.template Alloc(out); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, funcs::SubtractFunctor(), -1); +} + +template +void MultiplyKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + std::vector inputs; + inputs.reserve(2); + std::vector outputs; + outputs.reserve(1); + inputs.emplace_back(&x); + inputs.emplace_back(&y); + outputs.emplace_back(out); + dev_ctx.template Alloc(out); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, funcs::MultiplyFunctor(), -1); +} + +template +void DivideKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + std::vector inputs; + inputs.reserve(2); + std::vector outputs; + outputs.reserve(1); + inputs.emplace_back(&x); + inputs.emplace_back(&y); + outputs.emplace_back(out); + dev_ctx.template Alloc(out); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, funcs::DivideFunctor(), -1); +} + +template +void AddKernelImpl(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + std::vector 
inputs = {&x, &y}; + std::vector outputs = {out}; + dev_ctx.template Alloc(out); + funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, funcs::AddFunctor(), axis); +} + +template +void MultiPrecisionAddKernelImpl(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + std::vector inputs = {&x, &y}; + std::vector outputs = {out}; + if (y.dtype() == phi::DataType::BFLOAT16) { + funcs::ElementwiseKernel( + dev_ctx, + inputs, + &outputs, + funcs::MultiPrecisionAddFunctor()); + } else if (y.dtype() == phi::DataType::FLOAT16) { + funcs::ElementwiseKernel( + dev_ctx, + inputs, + &outputs, + funcs::MultiPrecisionAddFunctor()); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Unsupported x dtype:%s, y dtype:%s for add(x, y) operation", + phi::DataTypeToString(x.type()), + phi::DataTypeToString(y.type()))); + } +} + +template +void AddKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { +#ifdef PADDLE_WITH_CUDA + if (x.dtype() == phi::DataType::FLOAT32 && + (y.dtype() == phi::DataType::BFLOAT16 || + y.dtype() == phi::DataType::FLOAT16)) { + MultiPrecisionAddKernelImpl(dev_ctx, x, y, out); + } else { +#endif + AddKernelImpl(dev_ctx, x, y, -1, out); +#ifdef PADDLE_WITH_CUDA + } +#endif +} + +template +void GradAddKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + AddKernelImpl(dev_ctx, x, y, -1, out); +} + template void MaximumKernel(const Context& dev_ctx, const DenseTensor& x, @@ -58,6 +176,7 @@ void FloorDivideKernel(const Context& dev_ctx, int axis = -1; FloorDivideRawKernel(dev_ctx, x, y, axis, out); } + // Create the definition of Heaviside template void HeavisideKernel(const Context& dev_ctx, @@ -148,6 +267,10 @@ PD_REGISTER_KERNEL(elementwise_pow, #ifdef PADDLE_WITH_XPU_KP PD_REGISTER_KERNEL(maximum, KPS, ALL_LAYOUT, phi::MaximumKernel, float) {} PD_REGISTER_KERNEL(minimum, KPS, ALL_LAYOUT, phi::MinimumKernel, float) {} +PD_REGISTER_KERNEL(divide, KPS, ALL_LAYOUT, phi::DivideKernel, float) {} +PD_REGISTER_KERNEL(multiply, KPS, ALL_LAYOUT, phi::MultiplyKernel, float) {} +PD_REGISTER_KERNEL(add, KPS, ALL_LAYOUT, phi::AddKernel, float) {} +PD_REGISTER_KERNEL(subtract, KPS, ALL_LAYOUT, phi::SubtractKernel, float) {} PD_REGISTER_KERNEL(floor_divide, KPS, ALL_LAYOUT, phi::FloorDivideKernel, int) { } PD_REGISTER_KERNEL( @@ -191,4 +314,74 @@ PD_REGISTER_KERNEL(heaviside, float16, bfloat16, int64_t) {} + +PD_REGISTER_KERNEL(add, + KPS, + ALL_LAYOUT, + phi::AddKernel, + float, + double, + int16_t, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16, + complex64, + complex128) {} + +PD_REGISTER_KERNEL(grad_add, + KPS, + ALL_LAYOUT, + phi::GradAddKernel, + float, + double, + int16_t, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16, + complex64, + complex128) {} + +PD_REGISTER_KERNEL(divide, + KPS, + ALL_LAYOUT, + phi::DivideKernel, + float, + double, + int, + int64_t, + float16, + bfloat16, + complex64, + complex128) {} + +PD_REGISTER_KERNEL(multiply, + KPS, + ALL_LAYOUT, + phi::MultiplyKernel, + float, + double, + int, + int64_t, + bool, + float16, + complex64, + complex128, + bfloat16) {} + +PD_REGISTER_KERNEL(subtract, + KPS, + ALL_LAYOUT, + phi::SubtractKernel, + float, + double, + int16_t, + int, + int64_t, + float16, + bfloat16, + complex64, + complex128) {} + #endif diff --git a/paddle/phi/kernels/kps/elementwise_multiply_kernel.cu b/paddle/phi/kernels/kps/elementwise_multiply_kernel.cu deleted file mode 100644 index 
120f49e0d0b7d..0000000000000 --- a/paddle/phi/kernels/kps/elementwise_multiply_kernel.cu +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/backends/gpu/gpu_context.h" -#ifndef PADDLE_WITH_XPU_KP -#include "paddle/phi/common/complex.h" -#include "paddle/phi/common/float16.h" -#endif -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h" - -namespace phi { - -template -void MultiplyKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { - std::vector inputs; - inputs.reserve(2); - std::vector outputs; - outputs.reserve(1); - inputs.emplace_back(&x); - inputs.emplace_back(&y); - outputs.emplace_back(out); - dev_ctx.template Alloc(out); - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, funcs::MultiplyFunctor(), -1); -} - -} // namespace phi - -#ifdef PADDLE_WITH_XPU_KP -PD_REGISTER_KERNEL(multiply, KPS, ALL_LAYOUT, phi::MultiplyKernel, float) {} -#else - -using float16 = phi::dtype::float16; -using bfloat16 = phi::dtype::bfloat16; -using complex64 = ::phi::dtype::complex; -using complex128 = ::phi::dtype::complex; - -PD_REGISTER_KERNEL(multiply, - KPS, - ALL_LAYOUT, - phi::MultiplyKernel, - float, - double, - int, - int64_t, - bool, - float16, - complex64, - complex128, - bfloat16) {} - -#endif diff --git a/paddle/phi/kernels/kps/elementwise_subtract_kernel.cu b/paddle/phi/kernels/kps/elementwise_subtract_kernel.cu deleted file mode 100644 index 4f6015990e216..0000000000000 --- a/paddle/phi/kernels/kps/elementwise_subtract_kernel.cu +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
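The multiply kernel deleted above and the subtract kernel deleted below are thin wrappers over funcs::BroadcastKernel, whose semantics are NumPy-style broadcasting of mismatched shapes. A tiny standalone model of that rule, restricted to a (3,1) + (1,4) pair for brevity:

#include <cstdio>
#include <vector>

int main() {
  const int M = 3, N = 4;
  std::vector<float> x(M, 1.0f);  // shape (3,1)
  std::vector<float> y(N, 2.0f);  // shape (1,4)
  std::vector<float> out(M * N);  // broadcast result, shape (3,4)
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      out[i * N + j] = x[i] + y[j];  // size-1 axes are stretched
    }
  }
  std::printf("out[0] = %g\n", out[0]);  // 3
  return 0;
}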
- -#include "paddle/phi/backends/gpu/gpu_context.h" -#ifndef PADDLE_WITH_XPU_KP -#include "paddle/phi/common/complex.h" -#include "paddle/phi/common/float16.h" -#endif -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h" - -namespace phi { - -template -void SubtractKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { - std::vector inputs; - inputs.reserve(2); - std::vector outputs; - outputs.reserve(1); - inputs.emplace_back(&x); - inputs.emplace_back(&y); - outputs.emplace_back(out); - dev_ctx.template Alloc(out); - funcs::BroadcastKernel( - dev_ctx, inputs, &outputs, funcs::SubtractFunctor(), -1); -} - -} // namespace phi - -#ifdef PADDLE_WITH_XPU_KP -PD_REGISTER_KERNEL(subtract, KPS, ALL_LAYOUT, phi::SubtractKernel, float) {} -#else - -using float16 = phi::dtype::float16; -using bfloat16 = phi::dtype::bfloat16; -using complex64 = ::phi::dtype::complex; -using complex128 = ::phi::dtype::complex; - -PD_REGISTER_KERNEL(subtract, - KPS, - ALL_LAYOUT, - phi::SubtractKernel, - float, - double, - int16_t, - int, - int64_t, - float16, - bfloat16, - complex64, - complex128) {} - -#endif diff --git a/paddle/phi/kernels/kps/logical_kernel.cu b/paddle/phi/kernels/kps/logical_kernel.cu index f7c390e65d0ff..5e62ab2684f7a 100644 --- a/paddle/phi/kernels/kps/logical_kernel.cu +++ b/paddle/phi/kernels/kps/logical_kernel.cu @@ -25,24 +25,45 @@ namespace phi { -#define DEFINE_LOGICAL_BINARY_KERNEL(type) \ - template \ - void Logical##type##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - DenseTensor* out) { \ - if (!out->IsSharedWith(x)) { \ - dev_ctx.template Alloc(out); \ - } \ - \ - funcs::Logical##type##Functor binary_func; \ - std::vector ins = {&x, &y}; \ - std::vector outs = {out}; \ - if (!out->IsSharedWith(x)) { \ - funcs::BroadcastKernel(dev_ctx, ins, &outs, binary_func); \ - } else { \ - funcs::BroadcastKernel(dev_ctx, ins, &outs, binary_func); \ - } \ +template +void LogicalKernelImpl(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + dev_ctx.template Alloc(out); + Functor binary_func; + std::vector ins = {&x, &y}; + std::vector outs = {out}; + funcs::BroadcastKernel(dev_ctx, ins, &outs, binary_func); +} + +template +void InplaceLogicalKernelImpl(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + auto x_origin = x; + dev_ctx.template Alloc(out); + out->set_type(phi::DataType::BOOL); + Functor binary_func; + std::vector ins = {&x_origin, &y}; + std::vector outs = {out}; + funcs::BroadcastKernel(dev_ctx, ins, &outs, binary_func); +} + +#define DEFINE_LOGICAL_BINARY_KERNEL(type) \ + template \ + void Logical##type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + if (out->IsSharedWith(x)) { \ + InplaceLogicalKernelImpl>( \ + dev_ctx, x, y, out); \ + } else { \ + LogicalKernelImpl>( \ + dev_ctx, x, y, out); \ + } \ } DEFINE_LOGICAL_BINARY_KERNEL(And) @@ -56,14 +77,18 @@ void LogicalNotKernel(const Context& dev_ctx, DenseTensor* out) { if (!out->IsSharedWith(x)) { dev_ctx.template Alloc(out); - } - funcs::LogicalNotFunctor unary_func; - std::vector ins = {&x}; - std::vector outs = {out}; - if (!out->IsSharedWith(x)) { + funcs::LogicalNotFunctor unary_func; + std::vector ins = {&x}; + std::vector outs = {out}; funcs::BroadcastKernel(dev_ctx, ins, &outs, unary_func); } else { - funcs::BroadcastKernel(dev_ctx, 
ins, &outs, unary_func); + auto x_origin = x; + out->set_type(phi::DataType::BOOL); + dev_ctx.template Alloc(out); + funcs::LogicalNotFunctor unary_func; + std::vector ins = {&x_origin}; + std::vector outs = {out}; + funcs::BroadcastKernel(dev_ctx, ins, &outs, unary_func); } } @@ -99,7 +124,9 @@ PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) { int8_t, \ phi::dtype::complex, \ phi::dtype::complex, \ - int16_t) {} + int16_t) { \ + kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ + } REGISTER_LOGICAL_CUDA_KERNEL(logical_and, And) REGISTER_LOGICAL_CUDA_KERNEL(logical_or, Or) diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc index a74fd2aa9bd4e..7a732878ff64c 100644 --- a/paddle/phi/kernels/xpu/adamw_kernel.cc +++ b/paddle/phi/kernels/xpu/adamw_kernel.cc @@ -20,11 +20,30 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" namespace phi { +template +float GetAbsMax(const Context& dev_ctx, + const float* input, + float* buffer_xpu, + int64_t numel) { + float buffer_cpu[6]; + // int findmax(Context* ctx, const T* x, float* maxptr, int64_t len); + int r = xpu::findmax(dev_ctx.x_context(), input, buffer_xpu, numel); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax"); + memory_utils::Copy(CPUPlace(), + static_cast(buffer_cpu), + dev_ctx.GetPlace(), + static_cast(buffer_xpu), + sizeof(float) * 6); + float* max_value = std::max_element(buffer_cpu, buffer_cpu + 6); + return *max_value; +} + template void AdamwDenseKernel(const Context& dev_ctx, const DenseTensor& param, @@ -52,6 +71,98 @@ void AdamwDenseKernel(const Context& dev_ctx, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { + // check moment_dtype + auto moment1_dtype = moment1.dtype(); + auto moment2_dtype = moment2.dtype(); + PADDLE_ENFORCE_EQ(moment1_dtype, + moment1_out->dtype(), + errors::InvalidArgument( + "moment1.dtype does not match moment1_out->dtype")); + PADDLE_ENFORCE_EQ(moment2_dtype, + moment2_out->dtype(), + errors::InvalidArgument( + "moment2.dtype does not match moment2_out->dtype")); + PADDLE_ENFORCE_EQ( + moment1_dtype, + moment2_dtype, + errors::InvalidArgument("moment1.dtype does not match moment2.dtype")); + + bool moment_in_fp16 = false; + if (moment1_dtype == phi::DataType::FLOAT16) { + moment_in_fp16 = true; + } else { + PADDLE_ENFORCE_EQ( + moment1_dtype, + phi::DataType::FLOAT32, + errors::InvalidArgument("moment1.dtype is neither fp32 nor fp16")); + } + + float* moment1_input_for_xdnn = nullptr; + float* moment2_input_for_xdnn = nullptr; + float* moment1_output_for_xdnn = nullptr; + float* moment2_output_for_xdnn = nullptr; + + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + if (moment_in_fp16) { + // allocate temp buffer on XPU + moment1_input_for_xdnn = RAII_GUARD.alloc_l3_or_gm(moment1.numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(moment1_input_for_xdnn); + moment2_input_for_xdnn = RAII_GUARD.alloc_l3_or_gm(moment2.numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(moment2_input_for_xdnn); + moment1_output_for_xdnn = + RAII_GUARD.alloc_l3_or_gm(moment1_out->numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(moment1_output_for_xdnn); + moment2_output_for_xdnn = + RAII_GUARD.alloc_l3_or_gm(moment2_out->numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(moment2_output_for_xdnn); + + int r = 0; + using XPUType16 = typename XPUTypeTrait::Type; + + // cast 
moment1 and moment2, from fp16 to fp32 + // int cast(Context* ctx, const TX* x, TY* y, int64_t len); + r = xpu::cast( + dev_ctx.x_context(), + reinterpret_cast( + moment1.template data()), + moment1_input_for_xdnn, + moment1.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment1 from fp16 to float"); + r = xpu::cast( + dev_ctx.x_context(), + reinterpret_cast( + moment2.template data()), + moment2_input_for_xdnn, + moment2.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment2 from fp16 to float"); + + // de-scale using meta's scale_value + // int scale(Context* ctx, const T* x, T* y, int64_t len, bool + // bias_after_scale, float _scale, float _bias); + phi::DenseTensorMeta moment1_meta = moment1.meta(); + if (moment1_meta.scale_value > 0) { + r = xpu::scale(dev_ctx.x_context(), + moment1_input_for_xdnn, + moment1_input_for_xdnn, + moment1.numel(), + false, + 1.0f / moment1_meta.scale_value, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "de-scale for moment1"); + } + phi::DenseTensorMeta moment2_meta = moment2.meta(); + if (moment2_meta.scale_value > 0) { + r = xpu::scale(dev_ctx.x_context(), + moment2_input_for_xdnn, + moment2_input_for_xdnn, + moment2.numel(), + false, + 1.0f / moment2_meta.scale_value, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "de-scale for moment2"); + } + } + using XPUType = typename XPUTypeTrait::Type; bool skip_update_ = false; if (skip_update.is_initialized()) { @@ -94,7 +205,7 @@ void AdamwDenseKernel(const Context& dev_ctx, if (!with_decay) { coeff = static_cast(0.0); } - xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + float* new_lr = RAII_GUARD.alloc_l3_or_gm(learning_rate.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(new_lr); int r = 0; @@ -107,17 +218,23 @@ void AdamwDenseKernel(const Context& dev_ctx, 0.0f); PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); + // int adamw(Context* ctx, const T* g, const float* mom1, const float* mom2, + // const T* param, const float* beta1_pow, const float* beta2_pow, const + // float* lr, float* moment1_out, float* moment2_out, T* param_out, float + // beta1, float beta2, float epsilon, float coeff, int64_t n); r = xpu::adamw( dev_ctx.x_context(), reinterpret_cast(grad.template data()), - moment1.template data(), - moment2.template data(), + moment_in_fp16 ? moment1_input_for_xdnn : moment1.template data(), + moment_in_fp16 ? moment2_input_for_xdnn : moment2.template data(), reinterpret_cast(param.template data()), beta1_pow_ptr, beta2_pow_ptr, new_lr, - dev_ctx.template Alloc(moment1_out), - dev_ctx.template Alloc(moment2_out), + moment_in_fp16 ? moment1_output_for_xdnn + : dev_ctx.template Alloc(moment1_out), + moment_in_fp16 ? 
moment2_output_for_xdnn + : dev_ctx.template Alloc(moment2_out), reinterpret_cast(dev_ctx.template Alloc(param_out)), beta1_, beta2_, @@ -126,6 +243,75 @@ void AdamwDenseKernel(const Context& dev_ctx, param.numel()); PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw"); + if (moment_in_fp16) { + int r = 0; + using XPUType16 = typename XPUTypeTrait::Type; + + // findmax and calculate scale_value for moment1 and moment2 + float* buffer_for_findmax = RAII_GUARD.alloc_l3_or_gm(6); + + // for moment1 + float moment1_max = GetAbsMax(dev_ctx, + moment1_output_for_xdnn, + buffer_for_findmax, + moment1_out->numel()); + float moment1_scale_value = 65504.0f / moment1_max / 2.0f; + // int scale(Context* ctx, const T* x, T* y, int64_t len, bool + // bias_after_scale, float _scale, float _bias); + r = xpu::scale(dev_ctx.x_context(), + moment1_output_for_xdnn, + moment1_output_for_xdnn, + moment1_out->numel(), + false, + moment1_scale_value, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS( + r, "scale before convert to fp16, for moment1_output_for_xdnn"); + // write to meta info + phi::DenseTensorMeta moment1_out_meta = moment1_out->meta(); + moment1_out_meta.scale_value = moment1_scale_value; + moment1_out->set_meta(moment1_out_meta); + + // for moment2 + float moment2_max = GetAbsMax(dev_ctx, + moment2_output_for_xdnn, + buffer_for_findmax, + moment2_out->numel()); + float moment2_scale_value = 65504.0f / moment2_max / 2.0f; + // int scale(Context* ctx, const T* x, T* y, int64_t len, bool + // bias_after_scale, float _scale, float _bias); + r = xpu::scale(dev_ctx.x_context(), + moment2_output_for_xdnn, + moment2_output_for_xdnn, + moment2_out->numel(), + false, + moment2_scale_value, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS( + r, "scale before convert to fp16, for moment2_output_for_xdnn"); + // write to meta info + phi::DenseTensorMeta moment2_out_meta = moment2_out->meta(); + moment2_out_meta.scale_value = moment2_scale_value; + moment2_out->set_meta(moment2_out_meta); + + // cast moment1 and moment2 output, from fp32 to fp16 + // int cast(Context* ctx, const TX* x, TY* y, int64_t len); + r = xpu::cast( + dev_ctx.x_context(), + moment1_output_for_xdnn, + reinterpret_cast( + dev_ctx.template Alloc(moment1_out)), + moment1.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment1_out from float to fp16"); + r = xpu::cast( + dev_ctx.x_context(), + moment2_output_for_xdnn, + reinterpret_cast( + dev_ctx.template Alloc(moment2_out)), + moment2.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment2_out from float to fp16"); + } + if (!use_global_beta_pow) { // update in cpu if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) { diff --git a/paddle/ir/CMakeLists.txt b/paddle/pir/CMakeLists.txt similarity index 91% rename from paddle/ir/CMakeLists.txt rename to paddle/pir/CMakeLists.txt index 5a778466b4c19..1f87a16ff36a6 100644 --- a/paddle/ir/CMakeLists.txt +++ b/paddle/pir/CMakeLists.txt @@ -43,31 +43,31 @@ add_subdirectory(dialect) if(WIN32) if(WITH_SHARED_IR) set(IR_NAME - ir.dll + pir.dll CACHE INTERNAL "" FORCE) else() set(IR_NAME - ir.lib + pir.lib CACHE INTERNAL "" FORCE) endif() elseif(APPLE) if(WITH_SHARED_IR) set(IR_NAME - libir.dylib + libpir.dylib CACHE INTERNAL "" FORCE) else() set(IR_NAME - libir.a + libpir.a CACHE INTERNAL "" FORCE) endif() else() if(WITH_SHARED_IR) set(IR_NAME - libir.so + libpir.so CACHE INTERNAL "" FORCE) else() set(IR_NAME - libir.a + libpir.a CACHE INTERNAL "" FORCE) endif() endif() @@ -78,7 +78,7 @@ set(IR_LIB get_property(ir_modules GLOBAL PROPERTY IR_MODULES) 
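Stepping back to the XPU AdamW change above before the IR rename continues: fp16 moments stay representable because each step stores them scaled, with scale = 65504 / max|moment| / 2 leaving a factor-of-two headroom below the fp16 maximum, and 1/scale applied when the values are read back. A standalone round-trip sketch (plain float buffers stand in for the XPU ones):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> m = {1e-4f, -3e-3f, 2e-2f};  // moment values
  float mx = 0.0f;
  for (float v : m) mx = std::max(mx, std::fabs(v));
  const float scale = 65504.0f / mx / 2.0f;  // 65504 is the fp16 max
  for (float& v : m) v *= scale;  // scaled copy is what gets cast to fp16
  for (float& v : m) v /= scale;  // de-scaled on the next step's read
  std::printf("%g %g %g\n", m[0], m[1], m[2]);  // original values recovered
  return 0;
}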
if(WITH_SHARED_IR) - add_library(ir SHARED ${ir_modules}) + add_library(pir SHARED ${ir_modules}) else() - add_library(ir STATIC ${ir_modules}) + add_library(pir STATIC ${ir_modules}) endif() diff --git a/paddle/pir/core/CMakeLists.txt b/paddle/pir/core/CMakeLists.txt new file mode 100644 index 0000000000000..0fffc4285e376 --- /dev/null +++ b/paddle/pir/core/CMakeLists.txt @@ -0,0 +1,9 @@ +set(NEWIR_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/pir") +set(NEWIR_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/pir") + +file(GLOB IR_SRCS "*.cc") + +file(GLOB IR_PARSER_SRCS "parser/*.cc") +list(APPEND IR_SRCS ${IR_PARSER_SRCS}) + +ir_library(pir_core SRCS ${IR_SRCS} DEPS ddim) diff --git a/paddle/ir/core/attribute.cc b/paddle/pir/core/attribute.cc similarity index 86% rename from paddle/ir/core/attribute.cc rename to paddle/pir/core/attribute.cc index 0eff9964292df..993076880fdda 100644 --- a/paddle/ir/core/attribute.cc +++ b/paddle/pir/core/attribute.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/attribute_base.h" -#include "paddle/ir/core/dialect.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/attribute_base.h" +#include "paddle/pir/core/dialect.h" -namespace ir { +namespace pir { IrContext *Attribute::ir_context() const { return dialect().ir_context(); } TypeId Attribute::type_id() { return storage_->abstract_attribute().type_id(); } @@ -29,4 +29,4 @@ const Dialect &Attribute::dialect() const { return storage_->abstract_attribute().dialect(); } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/attribute.h b/paddle/pir/core/attribute.h similarity index 87% rename from paddle/ir/core/attribute.h rename to paddle/pir/core/attribute.h index d83ea3b3c6045..86d6a62ceddfd 100644 --- a/paddle/ir/core/attribute.h +++ b/paddle/pir/core/attribute.h @@ -14,13 +14,13 @@ #pragma once -#include "paddle/ir/core/cast_utils.h" -#include "paddle/ir/core/type_id.h" +#include "paddle/pir/core/cast_utils.h" +#include "paddle/pir/core/type_id.h" constexpr char kAttrStopGradients[] = "stop_gradient"; constexpr char kAttrIsPersisable[] = "is_persisable"; -namespace ir { +namespace pir { class AttributeStorage; class AbstractAttribute; class IrContext; @@ -77,12 +77,12 @@ class IR_API Attribute { template bool isa() const { - return ir::isa(*this); + return pir::isa(*this); } template U dyn_cast() const { - return ir::dyn_cast(*this); + return pir::dyn_cast(*this); } friend struct std::hash; @@ -92,13 +92,13 @@ class IR_API Attribute { }; IR_API std::ostream &operator<<(std::ostream &os, Attribute attr); -} // namespace ir +} // namespace pir namespace std { template <> -struct hash { - std::size_t operator()(const ir::Attribute &obj) const { - return std::hash()(obj.storage_); +struct hash { + std::size_t operator()(const pir::Attribute &obj) const { + return std::hash()(obj.storage_); } }; } // namespace std diff --git a/paddle/ir/core/attribute_base.h b/paddle/pir/core/attribute_base.h similarity index 91% rename from paddle/ir/core/attribute_base.h rename to paddle/pir/core/attribute_base.h index daa3fed14f8a3..e0cbb0253700a 100644 --- a/paddle/ir/core/attribute_base.h +++ b/paddle/pir/core/attribute_base.h @@ -14,11 +14,11 @@ #pragma once -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/storage_manager.h" -#include "paddle/ir/core/type_id.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/storage_manager.h" 
+#include "paddle/pir/core/type_id.h" -namespace ir { +namespace pir { class Dialect; /// @@ -155,7 +155,7 @@ struct IR_API AttributeManager { template static T get(IrContext *ctx, Args &&...args) { return get( - ctx, ir::TypeId::get(), std::forward(args)...); + ctx, pir::TypeId::get(), std::forward(args)...); } /// @@ -204,7 +204,7 @@ struct IR_API AttributeManager { /// template static void RegisterAttribute(IrContext *ctx) { - RegisterAttribute(ctx, ir::TypeId::get()); + RegisterAttribute(ctx, pir::TypeId::get()); } /// @@ -242,25 +242,25 @@ struct IR_API AttributeManager { /// /// \brief Add some necessary functions to the custom Attribute class. /// -#define DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(concrete_attribute, storage_type) \ - using Storage = storage_type; \ - \ - const Storage *storage() const { \ - return static_cast(this->storage_); \ - } \ - \ - static ir::TypeId type_id() { \ - return ir::TypeId::get(); \ - } \ - \ - template \ - static bool classof(T val) { \ - return val.type_id() == type_id(); \ - } \ - \ - template \ - static concrete_attribute get(ir::IrContext *ctx, Args... args) { \ - return ir::AttributeManager::template get(ctx, \ - args...); \ +#define DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(concrete_attribute, storage_type) \ + using Storage = storage_type; \ + \ + const Storage *storage() const { \ + return static_cast(this->storage_); \ + } \ + \ + static pir::TypeId type_id() { \ + return pir::TypeId::get(); \ + } \ + \ + template \ + static bool classof(T val) { \ + return val.type_id() == type_id(); \ + } \ + \ + template \ + static concrete_attribute get(pir::IrContext *ctx, Args... args) { \ + return pir::AttributeManager::template get(ctx, \ + args...); \ } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/block.cc b/paddle/pir/core/block.cc similarity index 93% rename from paddle/ir/core/block.cc rename to paddle/pir/core/block.cc index 04d59e2582ebe..f92d532298150 100644 --- a/paddle/ir/core/block.cc +++ b/paddle/pir/core/block.cc @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
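The DECLARE_ATTRIBUTE_UTILITY_FUNCTOR macro above keys classof and storage lookup on pir::TypeId::get(); a common way such per-type ids are implemented (assumed here, not read from pir's sources) is one static tag object per template instantiation, whose address is process-unique:

#include <cstdio>

// One static byte per instantiation: its address serves as a unique id.
template <typename T>
struct TypeIdModel {
  static const void* get() {
    static const char tag = 0;
    return &tag;
  }
};

struct BoolAttribute {};
struct FloatAttribute {};

int main() {
  const bool distinct =
      TypeIdModel<BoolAttribute>::get() != TypeIdModel<FloatAttribute>::get();
  std::printf("distinct ids: %d\n", distinct);  // 1
  return 0;
}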
-#include "paddle/ir/core/block.h" +#include "paddle/pir/core/block.h" #include -#include "paddle/ir/core/enforce.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/region.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/region.h" -namespace ir { +namespace pir { Block::~Block() { assert(use_empty() && "block destroyed still has uses."); clear(); @@ -93,4 +93,4 @@ bool Block::TopoOrderCheck(const OpListType &op_list) { return true; } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/block.h b/paddle/pir/core/block.h similarity index 93% rename from paddle/ir/core/block.h rename to paddle/pir/core/block.h index 7e612d6318d36..3a8b4fafc345d 100644 --- a/paddle/ir/core/block.h +++ b/paddle/pir/core/block.h @@ -17,12 +17,12 @@ #include #include -#include "paddle/ir/core/block_operand.h" -#include "paddle/ir/core/dll_decl.h" -#include "paddle/ir/core/region.h" -#include "paddle/ir/core/use_iterator.h" +#include "paddle/pir/core/block_operand.h" +#include "paddle/pir/core/dll_decl.h" +#include "paddle/pir/core/region.h" +#include "paddle/pir/core/use_iterator.h" -namespace ir { +namespace pir { class Operation; class IR_API Block { @@ -89,4 +89,4 @@ class IR_API Block { Region::iterator position_; BlockOperand first_use_; }; -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/block_operand.cc b/paddle/pir/core/block_operand.cc similarity index 91% rename from paddle/ir/core/block_operand.cc rename to paddle/pir/core/block_operand.cc index f64a07fd50dfe..78dd9c0b5d14e 100644 --- a/paddle/ir/core/block_operand.cc +++ b/paddle/pir/core/block_operand.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/ir/core/block_operand.h" -#include "paddle/ir/core/block.h" -#include "paddle/ir/core/block_operand_impl.h" -#include "paddle/ir/core/enforce.h" +#include "paddle/pir/core/block_operand.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/block_operand_impl.h" +#include "paddle/pir/core/enforce.h" -namespace ir { +namespace pir { #define CHECK_BLOCKOPEREND_NULL_IMPL(func_name) \ IR_ENFORCE(impl_, \ @@ -75,7 +75,7 @@ void BlockOperandImpl::set_source(Block *source) { InsertToUdChain(); } -BlockOperandImpl::BlockOperandImpl(Block *source, ir::Operation *owner) +BlockOperandImpl::BlockOperandImpl(Block *source, pir::Operation *owner) : source_(source), owner_(owner) { if (!source) { return; @@ -110,4 +110,4 @@ void BlockOperandImpl::RemoveFromUdChain() { BlockOperandImpl::~BlockOperandImpl() { RemoveFromUdChain(); } } // namespace detail -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/block_operand.h b/paddle/pir/core/block_operand.h similarity index 93% rename from paddle/ir/core/block_operand.h rename to paddle/pir/core/block_operand.h index ec55a90a1c65d..9895af86e7ed7 100644 --- a/paddle/ir/core/block_operand.h +++ b/paddle/pir/core/block_operand.h @@ -14,10 +14,10 @@ #pragma once -#include "paddle/ir/core/cast_utils.h" -#include "paddle/ir/core/type.h" +#include "paddle/pir/core/cast_utils.h" +#include "paddle/pir/core/type.h" -namespace ir { +namespace pir { class Operation; class Value; class Block; @@ -70,4 +70,4 @@ class IR_API BlockOperand { detail::BlockOperandImpl *impl_{nullptr}; }; -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/block_operand_impl.h b/paddle/pir/core/block_operand_impl.h similarity index 94% rename from paddle/ir/core/block_operand_impl.h rename to paddle/pir/core/block_operand_impl.h index 53d8257c10032..1e0f8659a9c10 100644 --- a/paddle/ir/core/block_operand_impl.h +++ b/paddle/pir/core/block_operand_impl.h @@ -14,9 +14,9 @@ #pragma once -#include "paddle/ir/core/block_operand.h" +#include "paddle/pir/core/block_operand.h" -namespace ir { +namespace pir { class Operation; class Block; @@ -58,4 +58,4 @@ class BlockOperandImpl { }; } // namespace detail -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/builder.cc b/paddle/pir/core/builder.cc similarity index 92% rename from paddle/ir/core/builder.cc rename to paddle/pir/core/builder.cc index 1bfbd2e2a8ca8..a91428ba99080 100644 --- a/paddle/ir/core/builder.cc +++ b/paddle/pir/core/builder.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ir/core/builder.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/builtin_type.h" -#include "paddle/ir/core/region.h" -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/region.h" +#include "paddle/pir/core/value.h" -namespace ir { +namespace pir { /// Create an operation given the fields represented as an OperationState. 
diff --git a/paddle/ir/core/builder.cc b/paddle/pir/core/builder.cc
similarity index 92%
rename from paddle/ir/core/builder.cc
rename to paddle/pir/core/builder.cc
index 1bfbd2e2a8ca8..a91428ba99080 100644
--- a/paddle/ir/core/builder.cc
+++ b/paddle/pir/core/builder.cc
@@ -12,13 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/ir/core/builder.h"
-#include "paddle/ir/core/builtin_attribute.h"
-#include "paddle/ir/core/builtin_type.h"
-#include "paddle/ir/core/region.h"
-#include "paddle/ir/core/value.h"
+#include "paddle/pir/core/builder.h"
+#include "paddle/pir/core/builtin_attribute.h"
+#include "paddle/pir/core/builtin_type.h"
+#include "paddle/pir/core/region.h"
+#include "paddle/pir/core/value.h"
 
-namespace ir {
+namespace pir {
 
 /// Create an operation given the fields represented as an OperationState.
 Operation *Builder::Build(OperationArgument &&argument) {
   return Insert(Operation::Create(std::move(argument)));
@@ -81,4 +81,4 @@ PointerAttribute Builder::pointer_attr(void *value) {
   return PointerAttribute::get(context_, value);
 }
 
-}  // namespace ir
+}  // namespace pir
diff --git a/paddle/ir/core/builder.h b/paddle/pir/core/builder.h
similarity index 92%
rename from paddle/ir/core/builder.h
rename to paddle/pir/core/builder.h
index f3ae837ea9723..acb621e7808e7 100644
--- a/paddle/ir/core/builder.h
+++ b/paddle/pir/core/builder.h
@@ -16,11 +16,11 @@
 
 #include <list>
 
-#include "paddle/ir/core/block.h"
-#include "paddle/ir/core/ir_context.h"
-#include "paddle/ir/core/operation.h"
+#include "paddle/pir/core/block.h"
+#include "paddle/pir/core/ir_context.h"
+#include "paddle/pir/core/operation.h"
 
-namespace ir {
+namespace pir {
 class Type;
 class UInt8Type;
 class Int8Type;
@@ -97,10 +97,10 @@ class Builder {
   IR_API Operation *Build(OperationArgument &&argument);
 
   /// Creates an operation with the given fields.
-  IR_API Operation *Build(const std::vector<ir::OpResult> &inputs,
+  IR_API Operation *Build(const std::vector<pir::OpResult> &inputs,
                           const AttributeMap &attribute,
-                          const std::vector<ir::Type> &output_types,
-                          ir::OpInfo op_info);
+                          const std::vector<pir::Type> &output_types,
+                          pir::OpInfo op_info);
 
   /// Create an operation of specific op type at the current insertion point.
   template <typename OpTy, typename... Args>
@@ -141,4 +141,4 @@ class Builder {
   Block::iterator insert_point_;
 };
 
-}  // namespace ir
+}  // namespace pir
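A compact sketch of the renamed Builder in use; Build&lt;OpTy&gt; forwards its arguments to OpTy::Build at the insertion point (the Program constructor and float32_type() accessor are assumed from elsewhere in the tree, not shown in this diff):

    pir::IrContext *ctx = pir::IrContext::Instance();
    pir::Program program(ctx);                   // assumed Program ctor
    pir::Builder builder(ctx, program.block());
    // GetParameterOp::Build(builder, argument, name, type) appears later in
    // this diff, so this call forwards ("w", f32) to it.
    pir::Type f32 = builder.float32_type();      // assumed convenience accessor
    auto get_param = builder.Build<pir::GetParameterOp>("w", f32);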
-#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_attribute.h" -namespace ir { +namespace pir { bool BoolAttribute::data() const { return storage()->data(); } @@ -37,7 +37,7 @@ std::string StrAttribute::AsString() const { return storage()->AsString(); } size_t StrAttribute::size() const { return storage()->size(); } -StrAttribute StrAttribute::get(ir::IrContext* ctx, const std::string& value) { +StrAttribute StrAttribute::get(pir::IrContext* ctx, const std::string& value) { return AttributeManager::get(ctx, value); } @@ -79,14 +79,14 @@ ArrayAttributeStorage::~ArrayAttributeStorage() { } } -} // namespace ir - -IR_DEFINE_EXPLICIT_TYPE_ID(ir::StrAttribute) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolAttribute) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::FloatAttribute) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::DoubleAttribute) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Attribute) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Attribute) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::ArrayAttribute) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::PointerAttribute) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::TypeAttribute) +} // namespace pir + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::StrAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::BoolAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::FloatAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::DoubleAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Int32Attribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Int64Attribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::ArrayAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::PointerAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::TypeAttribute) diff --git a/paddle/ir/core/builtin_attribute.h b/paddle/pir/core/builtin_attribute.h similarity index 80% rename from paddle/ir/core/builtin_attribute.h rename to paddle/pir/core/builtin_attribute.h index 3969d962e1f4e..7d3f86144915c 100644 --- a/paddle/ir/core/builtin_attribute.h +++ b/paddle/pir/core/builtin_attribute.h @@ -14,11 +14,11 @@ #pragma once -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/builtin_attribute_storage.h" -#include "paddle/ir/core/utils.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/builtin_attribute_storage.h" +#include "paddle/pir/core/utils.h" -namespace ir { +namespace pir { class IR_API BoolAttribute : public Attribute { public: using Attribute::Attribute; @@ -115,14 +115,14 @@ class IR_API ArrayAttribute : public Attribute { const std::vector& value); }; -} // namespace ir - -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::StrAttribute) -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolAttribute) -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::FloatAttribute) -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::DoubleAttribute) -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Attribute) -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Attribute) -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ArrayAttribute) -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::PointerAttribute) -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::TypeAttribute) +} // namespace pir + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::StrAttribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::BoolAttribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::FloatAttribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::DoubleAttribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Int32Attribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Int64Attribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ArrayAttribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::PointerAttribute) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::TypeAttribute) diff --git a/paddle/ir/core/builtin_attribute_storage.h b/paddle/pir/core/builtin_attribute_storage.h similarity 
diff --git a/paddle/ir/core/builtin_attribute_storage.h b/paddle/pir/core/builtin_attribute_storage.h
similarity index 95%
rename from paddle/ir/core/builtin_attribute_storage.h
rename to paddle/pir/core/builtin_attribute_storage.h
index 624abaf004718..fd9dd6eb87128 100644
--- a/paddle/ir/core/builtin_attribute_storage.h
+++ b/paddle/pir/core/builtin_attribute_storage.h
@@ -18,13 +18,13 @@
 #include <string>
 #include <vector>
 
-#include "paddle/ir/core/attribute.h"
-#include "paddle/ir/core/attribute_base.h"
-#include "paddle/ir/core/enforce.h"
-#include "paddle/ir/core/type.h"
-#include "paddle/ir/core/utils.h"
+#include "paddle/pir/core/attribute.h"
+#include "paddle/pir/core/attribute_base.h"
+#include "paddle/pir/core/enforce.h"
+#include "paddle/pir/core/type.h"
+#include "paddle/pir/core/utils.h"
 
-namespace ir {
+namespace pir {
 
 #define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(ConcreteStorage, BaseType) \
   struct ConcreteStorage : public AttributeStorage {                   \
@@ -147,4 +147,4 @@ struct ArrayAttributeStorage : public AttributeStorage {
   const size_t size_;
 };
 
-}  // namespace ir
+}  // namespace pir
diff --git a/paddle/ir/core/builtin_dialect.cc b/paddle/pir/core/builtin_dialect.cc
similarity index 87%
rename from paddle/ir/core/builtin_dialect.cc
rename to paddle/pir/core/builtin_dialect.cc
index 375bf90d2b8fd..23ba43c3d292e 100644
--- a/paddle/ir/core/builtin_dialect.cc
+++ b/paddle/pir/core/builtin_dialect.cc
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/ir/core/builtin_dialect.h"
-#include "paddle/ir/core/builtin_attribute.h"
-#include "paddle/ir/core/builtin_op.h"
-#include "paddle/ir/core/builtin_type.h"
+#include "paddle/pir/core/builtin_dialect.h"
+#include "paddle/pir/core/builtin_attribute.h"
+#include "paddle/pir/core/builtin_op.h"
+#include "paddle/pir/core/builtin_type.h"
 
-namespace ir {
+namespace pir {
 BuiltinDialect::BuiltinDialect(IrContext *context)
     : Dialect(name(), context, TypeId::get<BuiltinDialect>()) {
   initialize();
@@ -59,6 +59,6 @@ void BuiltinDialect::initialize() {
               ConstantOp>();
 }
 
-}  // namespace ir
+}  // namespace pir
 
-IR_DEFINE_EXPLICIT_TYPE_ID(ir::BuiltinDialect)
+IR_DEFINE_EXPLICIT_TYPE_ID(pir::BuiltinDialect)
diff --git a/paddle/ir/core/builtin_dialect.h b/paddle/pir/core/builtin_dialect.h
similarity index 82%
rename from paddle/ir/core/builtin_dialect.h
rename to paddle/pir/core/builtin_dialect.h
index c5872f8142e7b..13e669102d8cc 100644
--- a/paddle/ir/core/builtin_dialect.h
+++ b/paddle/pir/core/builtin_dialect.h
@@ -14,17 +14,17 @@
 
 #pragma once
 
-#include "paddle/ir/core/dialect.h"
+#include "paddle/pir/core/dialect.h"
 
-namespace ir {
+namespace pir {
 ///
 /// \brief Built-in Dialect: automatically registered into global IrContext,
 /// all built-in types defined in builtin_type.h will be registered in this
 /// Dialect.
 ///
-class IR_API BuiltinDialect : public ir::Dialect {
+class IR_API BuiltinDialect : public pir::Dialect {
  public:
-  explicit BuiltinDialect(ir::IrContext *context);
+  explicit BuiltinDialect(pir::IrContext *context);
   ///
   /// \brief Each Dialect needs to provide a name function to return the name of
   /// the Dialect.
@@ -37,6 +37,6 @@ class IR_API BuiltinDialect : public ir::Dialect {
   void initialize();
 };
 
-}  // namespace ir
+}  // namespace pir
 
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BuiltinDialect)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::BuiltinDialect)
diff --git a/paddle/ir/core/builtin_op.cc b/paddle/pir/core/builtin_op.cc
similarity index 83%
rename from paddle/ir/core/builtin_op.cc
rename to paddle/pir/core/builtin_op.cc
index 1feb4d691d99b..aba3ff9b282e4 100644
--- a/paddle/ir/core/builtin_op.cc
+++ b/paddle/pir/core/builtin_op.cc
@@ -12,13 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/ir/core/builtin_op.h"
-#include "paddle/ir/core/builtin_attribute.h"
-#include "paddle/ir/core/builtin_type.h"
-#include "paddle/ir/core/enforce.h"
+#include "paddle/pir/core/builtin_op.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/pir/core/builtin_attribute.h"
+#include "paddle/pir/core/builtin_type.h"
+#include "paddle/pir/core/enforce.h"
 
-namespace ir {
+namespace pir {
 
 const char *ModuleOp::attributes_name[attributes_num] = {"program"};  // NOLINT
 
@@ -38,7 +38,7 @@ Block *ModuleOp::block() {
 }
 
 ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
-  ir::OpInfo info = context->GetRegisteredOpInfo(name());
+  pir::OpInfo info = context->GetRegisteredOpInfo(name());
   OperationArgument argument(info);
   argument.num_regions = 1;
   argument.AddAttribute("program", PointerAttribute::get(context, pointer));
@@ -77,7 +77,7 @@ void GetParameterOp::Build(Builder &builder,
                            const std::string &name,
                            Type type) {
   argument.attributes[attributes_name[0]] =
-      ir::StrAttribute::get(builder.ir_context(), name);
+      pir::StrAttribute::get(builder.ir_context(), name);
   argument.output_types.emplace_back(type);
 }
 
@@ -105,7 +105,7 @@ void SetParameterOp::Build(Builder &builder,  // NOLINT
                            const std::string &name) {
   argument.AddOperand(parameter);
   argument.AddAttribute(attributes_name[0],
-                        ir::StrAttribute::get(builder.ir_context(), name));
+                        pir::StrAttribute::get(builder.ir_context(), name));
 }
 
 void SetParameterOp::Verify() const {
   VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
@@ -124,14 +124,18 @@ void SetParameterOp::Verify() const {
 
 void CombineOp::Build(Builder &builder,
                       OperationArgument &argument,
-                      const std::vector<ir::OpResult> &inputs) {
+                      const std::vector<pir::OpResult> &inputs) {
   argument.inputs = inputs;
-  std::vector<ir::Type> inputs_type(inputs.size());
-  for (size_t idx = 0; idx < inputs.size(); ++idx) {
-    inputs_type[idx] = inputs[idx].type();
+  if (inputs.size() == 0) {
+    argument.output_types.emplace_back(pir::Type());
+  } else {
+    std::vector<pir::Type> inputs_type(inputs.size());
+    for (size_t idx = 0; idx < inputs.size(); ++idx) {
+      inputs_type[idx] = inputs[idx].type();
+    }
+    argument.output_types.emplace_back(
+        pir::VectorType::get(builder.ir_context(), inputs_type));
   }
-  argument.output_types.emplace_back(
-      ir::VectorType::get(builder.ir_context(), inputs_type));
 }
 
 void CombineOp::Verify() const {
@@ -167,11 +171,11 @@ void CombineOp::Verify() const {
 const char *SliceOp::attributes_name[attributes_num] = {"index"};  // NOLINT
 
 void SliceOp::Build(Builder &builder,
                     OperationArgument &argument,
-                    const ir::OpResult &input,
+                    const pir::OpResult &input,
                     int index) {
   argument.inputs = {input};
   argument.output_types.emplace_back(input.type()
-                                         .dyn_cast<ir::VectorType>()
+                                         .dyn_cast<pir::VectorType>()
                                          .data()[static_cast<size_t>(index)]);
 }
 
@@ -182,7 +186,7 @@ void SliceOp::Verify() const {
       input_size == 1, "The size %d of inputs must be equal to 1.", input_size);
 
   // inputs[0].type == Vector
input_type = (*this)->operand(0).type().dyn_cast(); + auto input_type = (*this)->operand(0).type().dyn_cast(); IR_ENFORCE(input_type, "The type %s of inputs[0] must be equal to VectorType.", input_type); @@ -197,10 +201,10 @@ void SliceOp::Verify() const { auto &attributes = this->attributes(); IR_ENFORCE(attributes.count("index") != 0, "The attributes must contains index."); - const ir::Attribute &attr = attributes.at("index"); - IR_ENFORCE(attr.isa(), + const pir::Attribute &attr = attributes.at("index"); + IR_ENFORCE(attr.isa(), "The attribute index must be INT32."); - auto index = attr.dyn_cast().data(); + auto index = attr.dyn_cast().data(); // index >= 0 and < inputs[0].size() IR_ENFORCE( @@ -222,12 +226,12 @@ void SliceOp::Verify() const { void SplitOp::Build(Builder &builder, OperationArgument &argument, - const ir::OpResult &input) { + const pir::OpResult &input) { argument.inputs = {input}; - for (size_t idx = 0; idx < input.type().dyn_cast().size(); + for (size_t idx = 0; idx < input.type().dyn_cast().size(); ++idx) { argument.output_types.emplace_back( - input.type().dyn_cast().data()[idx]); + input.type().dyn_cast().data()[idx]); } } @@ -277,13 +281,13 @@ void ConstantOp::Verify() const { Attribute ConstantOp::value() const { return attributes().at("value"); } -} // namespace ir +} // namespace pir -IR_DEFINE_EXPLICIT_TYPE_ID(ir::ModuleOp) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::GetParameterOp) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::SetParameterOp) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::CombineOp) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::SliceOp) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::SplitOp) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::ModuleOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::GetParameterOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::SetParameterOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::CombineOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::SliceOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::SplitOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::ConstantLikeTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::ConstantOp) diff --git a/paddle/ir/core/builtin_op.h b/paddle/pir/core/builtin_op.h similarity index 76% rename from paddle/ir/core/builtin_op.h rename to paddle/pir/core/builtin_op.h index ab2d0cb9efba6..fee0ca406a741 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/pir/core/builtin_op.h @@ -14,17 +14,17 @@ #pragma once -#include "paddle/ir/core/builder.h" -#include "paddle/ir/core/op_base.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/op_base.h" -namespace ir { +namespace pir { class Program; class Block; /// /// \brief ModuleOp /// -class IR_API ModuleOp : public ir::Op { +class IR_API ModuleOp : public pir::Op { public: using Op::Op; static const char *name() { return "builtin.module"; } @@ -45,7 +45,7 @@ class IR_API ModuleOp : public ir::Op { /// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute, /// StrAttribute}) /// -class IR_API GetParameterOp : public ir::Op { +class IR_API GetParameterOp : public pir::Op { public: using Op::Op; static const char *name() { return "builtin.get_parameter"; } @@ -62,7 +62,7 @@ class IR_API GetParameterOp : public ir::Op { /// \brief SetParameterOp: SetParameterOp(OpOperand, {StrAttribute, /// StrAttribute}) /// -class IR_API SetParameterOp : public ir::Op { +class IR_API SetParameterOp : public pir::Op { public: using Op::Op; static const char *name() { return "builtin.set_parameter"; } @@ -78,7 +78,7 @@ class IR_API SetParameterOp : public ir::Op { /// /// \brief CombineOp: CombineOp(OpOperand) /// 
diff --git a/paddle/ir/core/builtin_op.h b/paddle/pir/core/builtin_op.h
similarity index 76%
rename from paddle/ir/core/builtin_op.h
rename to paddle/pir/core/builtin_op.h
index ab2d0cb9efba6..fee0ca406a741 100644
--- a/paddle/ir/core/builtin_op.h
+++ b/paddle/pir/core/builtin_op.h
@@ -14,17 +14,17 @@
 
 #pragma once
 
-#include "paddle/ir/core/builder.h"
-#include "paddle/ir/core/op_base.h"
+#include "paddle/pir/core/builder.h"
+#include "paddle/pir/core/op_base.h"
 
-namespace ir {
+namespace pir {
 class Program;
 class Block;
 ///
 /// \brief ModuleOp
 ///
-class IR_API ModuleOp : public ir::Op<ModuleOp> {
+class IR_API ModuleOp : public pir::Op<ModuleOp> {
  public:
   using Op::Op;
   static const char *name() { return "builtin.module"; }
@@ -45,7 +45,7 @@ class IR_API ModuleOp : public pir::Op<ModuleOp> {
 /// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute,
 /// StrAttribute})
 ///
-class IR_API GetParameterOp : public ir::Op<GetParameterOp> {
+class IR_API GetParameterOp : public pir::Op<GetParameterOp> {
  public:
   using Op::Op;
   static const char *name() { return "builtin.get_parameter"; }
@@ -62,7 +62,7 @@ class IR_API GetParameterOp : public pir::Op<GetParameterOp> {
 /// \brief SetParameterOp: SetParameterOp(OpOperand, {StrAttribute,
 /// StrAttribute})
 ///
-class IR_API SetParameterOp : public ir::Op<SetParameterOp> {
+class IR_API SetParameterOp : public pir::Op<SetParameterOp> {
  public:
   using Op::Op;
   static const char *name() { return "builtin.set_parameter"; }
@@ -78,7 +78,7 @@ class IR_API SetParameterOp : public pir::Op<SetParameterOp> {
 ///
 /// \brief CombineOp: CombineOp(OpOperand)
 ///
-class IR_API CombineOp : public ir::Op<CombineOp> {
+class IR_API CombineOp : public pir::Op<CombineOp> {
  public:
   using Op::Op;
@@ -90,23 +90,23 @@ class IR_API CombineOp : public pir::Op<CombineOp> {
 
   static void Build(Builder &builder,             // NOLINT
                     OperationArgument &argument,  // NOLINT
-                    const std::vector<ir::OpResult> &inputs);
+                    const std::vector<pir::OpResult> &inputs);
 
   void Verify() const;
-  std::vector<ir::Value> inputs() {
-    std::vector<ir::Value> inputs;
+  std::vector<pir::Value> inputs() {
+    std::vector<pir::Value> inputs;
     for (uint32_t idx = 0; idx < num_operands(); idx++) {
       inputs.push_back(operand_source(static_cast<int>(idx)));
     }
     return inputs;
   }
-  ir::OpResult out() { return result(0); }
+  pir::OpResult out() { return result(0); }
 };
 
 ///
 /// \brief SliceOp: SliceOp(OpOperand)
 ///
-class IR_API SliceOp : public ir::Op<SliceOp> {
+class IR_API SliceOp : public pir::Op<SliceOp> {
  public:
   using Op::Op;
@@ -118,17 +118,17 @@ class IR_API SliceOp : public pir::Op<SliceOp> {
 
   static void Build(Builder &builder,             // NOLINT
                     OperationArgument &argument,  // NOLINT
-                    const ir::OpResult &input,
+                    const pir::OpResult &input,
                     int index);
 
   void Verify() const;
-  ir::Value input() { return operand_source(0); }
+  pir::Value input() { return operand_source(0); }
 };
 
 ///
 /// \brief SplitOp: SplitOp(OpOperand)
 ///
-class IR_API SplitOp : public ir::Op<SplitOp> {
+class IR_API SplitOp : public pir::Op<SplitOp> {
  public:
   using Op::Op;
@@ -140,12 +140,12 @@ class IR_API SplitOp : public pir::Op<SplitOp> {
 
   static void Build(Builder &builder,             // NOLINT
                     OperationArgument &argument,  // NOLINT
-                    const ir::OpResult &input);
+                    const pir::OpResult &input);
 
   void Verify() const;
-  ir::Value input() { return operand_source(0); }
-  std::vector<ir::OpResult> outputs() {
-    std::vector<ir::OpResult> outputs;
+  pir::Value input() { return operand_source(0); }
+  std::vector<pir::OpResult> outputs() {
+    std::vector<pir::OpResult> outputs;
     for (uint32_t idx = 0; idx < num_results(); idx++) {
       outputs.push_back(result(static_cast<int>(idx)));
     }
     return outputs;
@@ -180,13 +180,13 @@ class IR_API ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
   Attribute value() const;
 };
 
-}  // namespace ir
+}  // namespace pir
 
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ModuleOp)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::GetParameterOp)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SetParameterOp)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::CombineOp)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SliceOp)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SplitOp)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ConstantOp)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ModuleOp)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::GetParameterOp)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::SetParameterOp)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::CombineOp)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::SliceOp)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::SplitOp)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ConstantLikeTrait)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ConstantOp)
-#include "paddle/ir/core/builtin_type.h" +#include "paddle/pir/core/builtin_type.h" -namespace ir { +namespace pir { std::vector VectorType::data() const { return storage()->GetAsKey(); } -const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; } +const pir::Type& DenseTensorType::dtype() const { return storage()->dtype_; } const DenseTensorTypeStorage::Dim& DenseTensorType::dims() const { return storage()->dims_; @@ -32,20 +32,20 @@ const DenseTensorTypeStorage::LoD& DenseTensorType::lod() const { } const size_t& DenseTensorType::offset() const { return storage()->offset_; } -} // namespace ir - -IR_DEFINE_EXPLICIT_TYPE_ID(ir::UInt8Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int8Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::VectorType) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::BFloat16Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Float16Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Float32Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Float64Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int16Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int32Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Int64Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::IndexType) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::BoolType) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex64Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::Complex128Type) -IR_DEFINE_EXPLICIT_TYPE_ID(ir::DenseTensorType) +} // namespace pir + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::UInt8Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Int8Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::VectorType) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::BFloat16Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Float16Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Float32Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Float64Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Int16Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Int32Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Int64Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::IndexType) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::BoolType) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Complex64Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::Complex128Type) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::DenseTensorType) diff --git a/paddle/ir/core/builtin_type.h b/paddle/pir/core/builtin_type.h similarity index 52% rename from paddle/ir/core/builtin_type.h rename to paddle/pir/core/builtin_type.h index a660f065376b2..3f0e7a1471703 100644 --- a/paddle/ir/core/builtin_type.h +++ b/paddle/pir/core/builtin_type.h @@ -15,14 +15,13 @@ #pragma once -#include "paddle/ir/core/builtin_type_storage.h" -#include "paddle/ir/core/type.h" +#include "paddle/pir/core/builtin_type_interfaces.h" +#include "paddle/pir/core/builtin_type_storage.h" +#include "paddle/pir/core/type.h" -namespace ir { +namespace pir { /// -/// \brief Define built-in parameterless types. Please add the necessary -/// interface functions for built-in types through the macro -/// DECLARE_TYPE_UTILITY_FUNCTOR. +/// \brief Define built-in parameterless types. /// /// NOTE(zhangbo9674): If you need to directly /// cache the object of this built-in type in IrContext, please overload the get @@ -31,7 +30,7 @@ namespace ir { /// /// The built-in type object get method is as follows: /// \code{cpp} -/// ir::IrContext *ctx = ir::IrContext::Instance(); +/// pir::IrContext *ctx = pir::IrContext::Instance(); /// Type fp32 = Float32Type::get(ctx); /// \endcode /// @@ -39,11 +38,10 @@ namespace ir { // NOTE(dev): Currently Int8 are not considered as a cached member // in IrContextImpl because it is not widely used. 
diff --git a/paddle/ir/core/builtin_type.h b/paddle/pir/core/builtin_type.h
similarity index 52%
rename from paddle/ir/core/builtin_type.h
rename to paddle/pir/core/builtin_type.h
index a660f065376b2..3f0e7a1471703 100644
--- a/paddle/ir/core/builtin_type.h
+++ b/paddle/pir/core/builtin_type.h
@@ -15,14 +15,13 @@
 
 #pragma once
 
-#include "paddle/ir/core/builtin_type_storage.h"
-#include "paddle/ir/core/type.h"
+#include "paddle/pir/core/builtin_type_interfaces.h"
+#include "paddle/pir/core/builtin_type_storage.h"
+#include "paddle/pir/core/type.h"
 
-namespace ir {
+namespace pir {
 ///
-/// \brief Define built-in parameterless types. Please add the necessary
-/// interface functions for built-in types through the macro
-/// DECLARE_TYPE_UTILITY_FUNCTOR.
+/// \brief Define built-in parameterless types.
 ///
 /// NOTE(zhangbo9674): If you need to directly
 /// cache the object of this built-in type in IrContext, please overload the get
@@ -31,7 +30,7 @@ namespace ir {
 ///
 /// The built-in type object get method is as follows:
 /// \code{cpp}
-///   ir::IrContext *ctx = ir::IrContext::Instance();
+///   pir::IrContext *ctx = pir::IrContext::Instance();
 ///   Type fp32 = Float32Type::get(ctx);
 /// \endcode
 ///
@@ -39,11 +38,10 @@ namespace ir {
 
 // NOTE(dev): Currently Int8 are not considered as a cached member
 // in IrContextImpl because it is not widely used.
-class IR_API VectorType : public Type {
+class IR_API VectorType
+    : public pir::Type::TypeBase<VectorType, Type, VectorTypeStorage> {
  public:
-  using Type::Type;
-
-  DECLARE_TYPE_UTILITY_FUNCTOR(VectorType, VectorTypeStorage);
+  using Base::Base;
 
   std::vector<Type> data() const;
 
@@ -54,13 +52,14 @@ class IR_API VectorType : public Type {
   Type operator[](size_t index) const { return data()[index]; }
 };
 
-class DenseTensorType : public ir::Type {
+class DenseTensorType : public pir::Type::TypeBase<DenseTensorType,
+                                                   pir::Type,
+                                                   DenseTensorTypeStorage,
+                                                   pir::ShapedTypeInterface> {
  public:
-  using Type::Type;
-
-  DECLARE_TYPE_UTILITY_FUNCTOR(DenseTensorType, DenseTensorTypeStorage);
+  using Base::Base;
 
-  const ir::Type &dtype() const;
+  const pir::Type &dtype() const;
 
   const DenseTensorTypeStorage::Dim &dims() const;
 
@@ -71,14 +70,13 @@ class DenseTensorType : public pir::Type {
   const size_t &offset() const;
 };
 
-#define DECLARE_BUILTIN_TYPE(__name)                   \
-  class IR_API __name : public Type {                  \
-   public:                                             \
-    using Type::Type;                                  \
-                                                       \
-    DECLARE_TYPE_UTILITY_FUNCTOR(__name, TypeStorage); \
-                                                       \
-    static __name get(IrContext *context);             \
+#define DECLARE_BUILTIN_TYPE(__name)                                \
+  class IR_API __name : public ::pir::Type::TypeBase<__name,        \
+                                                     ::pir::Type,   \
+                                                     ::pir::TypeStorage> { \
+   public:                                                          \
+    using Base::Base;                                               \
+    static __name get(IrContext *context);                          \
   };
 
 #define FOREACH_BUILTIN_TYPE(__macro) \
@@ -101,20 +99,20 @@ FOREACH_BUILTIN_TYPE(DECLARE_BUILTIN_TYPE)
 #undef FOREACH_BUILTIN_TYPE
 #undef DECLARE_BUILTIN_TYPE
 
-}  // namespace ir
-
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::UInt8Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int8Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::VectorType)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BFloat16Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Float16Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Float32Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Float64Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int16Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int32Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Int64Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::BoolType)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::IndexType)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex64Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::Complex128Type)
-IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::DenseTensorType)
+}  // namespace pir
+
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::UInt8Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Int8Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::VectorType)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::BFloat16Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Float16Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Float32Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Float64Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Int16Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Int32Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Int64Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::BoolType)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::IndexType)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Complex64Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::Complex128Type)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::DenseTensorType)
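With DECLARE_TYPE_UTILITY_FUNCTOR retired, types now derive from the CRTP TypeBase, as VectorType and DenseTensorType do above. A hypothetical parameterless dialect type under the same pattern (MyTokenType is illustrative, not part of Paddle):

    // Hypothetical example: TypeStorage is the parameterless storage used by
    // the DECLARE_BUILTIN_TYPE macro above; dialect registration is omitted.
    class MyTokenType
        : public pir::Type::TypeBase<MyTokenType, pir::Type, pir::TypeStorage> {
     public:
      using Base::Base;
    };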
diff --git a/paddle/cinn/optim/tensor_write_tell.cc b/paddle/pir/core/builtin_type_interfaces.cc
similarity index 72%
rename from paddle/cinn/optim/tensor_write_tell.cc
rename to paddle/pir/core/builtin_type_interfaces.cc
index 9f0f5747c3f3d..9084bffc7a197 100644
--- a/paddle/cinn/optim/tensor_write_tell.cc
+++ b/paddle/pir/core/builtin_type_interfaces.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,8 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/cinn/optim/tensor_write_tell.h"
+#include "paddle/pir/core/builtin_type_interfaces.h"
+#include "paddle/pir/core/type_id.h"
 
-namespace cinn {
-namespace optim {}  // namespace optim
-}  // namespace cinn
+IR_DEFINE_EXPLICIT_TYPE_ID(pir::ShapedTypeInterface)
diff --git a/paddle/pir/core/builtin_type_interfaces.h b/paddle/pir/core/builtin_type_interfaces.h
new file mode 100644
index 0000000000000..f736c1a631b48
--- /dev/null
+++ b/paddle/pir/core/builtin_type_interfaces.h
@@ -0,0 +1,159 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+
+#include "paddle/phi/core/tensor_base.h"
+#include "paddle/pir/core/cast_utils.h"
+#include "paddle/pir/core/enforce.h"
+#include "paddle/pir/core/type.h"
+
+namespace details {
+
+template <typename RangeT>
+constexpr auto begin_impl(RangeT &&range)
+    -> decltype(std::begin(std::forward<RangeT>(range))) {
+  return std::begin(std::forward<RangeT>(range));
+}
+
+template <typename RangeT>
+constexpr auto end_impl(RangeT &&range)
+    -> decltype(std::end(std::forward<RangeT>(range))) {
+  return std::end(std::forward<RangeT>(range));
+}
+
+/// Returns the begin iterator to \p range using `std::begin` and
+/// function found through Argument-Dependent Lookup (ADL).
+template <typename RangeT>
+constexpr auto adl_begin(RangeT &&range)
+    -> decltype(begin_impl(std::forward<RangeT>(range))) {
+  return begin_impl(std::forward<RangeT>(range));
+}
+
+/// Returns the end iterator to \p range using `std::end` and
+/// functions found through Argument-Dependent Lookup (ADL).
+template <typename RangeT>
+constexpr auto adl_end(RangeT &&range)
+    -> decltype(end_impl(std::forward<RangeT>(range))) {
+  return end_impl(std::forward<RangeT>(range));
+}
+
+/// Provide wrappers to std::any_of which take ranges instead of having to pass
+/// begin/end explicitly.
+template <typename R, typename UnaryPredicate>
+bool any_of(R &&Range, UnaryPredicate P) {
+  return std::any_of(adl_begin(Range), adl_end(Range), P);
+}
+
+/// Wrapper function around std::count_if to count the number of times an
+/// element satisfying a given predicate occurs in a range.
+template <typename R, typename UnaryPredicate>
+auto count_if(R &&Range, UnaryPredicate P) {
+  return std::count_if(adl_begin(Range), adl_end(Range), P);
+}
+
+}  // namespace details
+
+namespace pir {
+
+class ShapedTypeInterface
+    : public pir::TypeInterfaceBase<ShapedTypeInterface> {
+ public:
+  using DDim = phi::DDim;
+  using DataType = pir::Type;
+
+  struct Concept {
+    /// Defined these methods with the interface.
+    explicit Concept(DataType (*get_element_type)(pir::Type),
+                     DDim (*get_shape)(pir::Type))
+        : get_element_type_(get_element_type), get_shape_(get_shape) {}
+
+    DataType (*get_element_type_)(pir::Type);
+    DDim (*get_shape_)(pir::Type);
+  };
+
+  template <typename ConcreteType>
+  struct Model : public Concept {
+    static inline DataType getElementType(pir::Type type) {
+      return pir::cast<ConcreteType>(type).dtype();
+    }
+
+    static inline DDim getShape(pir::Type type) {
+      return pir::cast<ConcreteType>(type).dims();
+    }
+
+    Model() : Concept(getElementType, getShape) {}
+  };
+
+  /// Constructor
+  ShapedTypeInterface(pir::Type type, Concept *impl)
+      : pir::TypeInterfaceBase<ShapedTypeInterface>(type), impl_(impl) {}
+
+  /// Get the element type.
+  DataType getElementType() const { return impl_->get_element_type_(*this); }
+
+  /// Get the shape of this type.
+  DDim getShape() const { return impl_->get_shape_(*this); }
+
+  static constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
+
+  /// Check whether this type is ranked, currently return true.
+  bool hasRank() const { return true; }
+
+  /// If this is a ranked type, return the rank. Otherwise, abort.
+  int64_t getRank() const {
+    IR_ENFORCE((*this).hasRank(), "Cannot query rank of unranked shaped type.");
+    return (*this).getShape().size();
+  }
+
+  /// Check whether the given dimension size is a dynamic dimension.
+  static constexpr bool isDynamic(int64_t dValue) { return dValue == kDynamic; }
+
+  /// Check whether the given shape has any size indicating a dynamic dimension.
+  static bool isDynamicShape(DDim dSizes) {
+    return ::details::any_of(vectorize(dSizes),
+                             [](int64_t dSize) { return isDynamic(dSize); });
+  }
+
+  /// Check whether shape has any size indicating a dynamic dimension.
+  bool hasStaticShape() const {
+    return (*this).hasRank() &&
+           !pir::ShapedTypeInterface::isDynamicShape((*this).getShape());
+  }
+
+  /// Check whether the given dimension has a dynamic size.
+  /// Aborts for unranked types.
+  bool isDynamicDim(unsigned idx) const {
+    IR_ENFORCE(idx < getRank(), "Invalid index for shaped type.");
+    return pir::ShapedTypeInterface::isDynamic((*this).getShape()[idx]);
+  }
+
+  /// Get the number of dimensions with dynamic size for a ranked type.
+  /// Aborts for unranked types.
+  int64_t getNumDynamicDims() const {
+    return ::details::count_if(vectorize((*this).getShape()),
+                               pir::ShapedTypeInterface::isDynamic);
+  }
+
+  /// Get the size of the specified dimension for a ranked type.
+  /// Aborts for unranked types.
+  int64_t getDimSize(unsigned idx) const {
+    IR_ENFORCE(idx < getRank(), "Invalid index for shaped type.");
+    return (*this).getShape()[idx];
+  }
+
+ private:
+  Concept *impl_;
+};
+
+}  // namespace pir
+
+IR_DECLARE_EXPLICIT_TYPE_ID(pir::ShapedTypeInterface)
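A rough sketch of querying the new interface; it assumes DenseTensorType attaches ShapedTypeInterface (which is what its Model above is written against) and that Type's dyn_cast reaches interfaces:

    pir::Type ty = /* a DenseTensorType built elsewhere */;
    auto shaped = ty.dyn_cast<pir::ShapedTypeInterface>();
    if (shaped) {
      int64_t rank = shaped.getRank();       // aborts on unranked types
      bool fixed = shaped.hasStaticShape();  // false if any dim == kDynamic
    }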
diff --git a/paddle/ir/core/builtin_type_storage.h b/paddle/pir/core/builtin_type_storage.h
similarity index 87%
rename from paddle/ir/core/builtin_type_storage.h
rename to paddle/pir/core/builtin_type_storage.h
index 4488b28b07fa2..b8b18d09ddd26 100644
--- a/paddle/ir/core/builtin_type_storage.h
+++ b/paddle/pir/core/builtin_type_storage.h
@@ -14,11 +14,11 @@
 
 #pragma once
 
-#include "paddle/ir/core/type.h"
-#include "paddle/ir/core/type_base.h"
-#include "paddle/ir/core/utils.h"
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/pir/core/type.h"
+#include "paddle/pir/core/type_base.h"
+#include "paddle/pir/core/utils.h"
 
 namespace std {
 ///
@@ -37,7 +37,7 @@ struct hash<std::vector<T>> {
   }
 }  // namespace std
 
-namespace ir {
+namespace pir {
 
 ///
 /// \brief Define Parametric TypeStorage for DenseTensorType.
 ///
@@ -46,16 +46,16 @@ namespace pir {
 /// (3)define HashValue method, (4)overload operator==.
 ///
-struct DenseTensorTypeStorage : public ir::TypeStorage {
+struct DenseTensorTypeStorage : public pir::TypeStorage {
   ///
   /// \brief Declare ParamKey according to parameter type.
   ///
   using DataLayout = phi::DataLayout;
   using Dim = phi::DDim;
   using LoD = std::vector<std::vector<size_t>>;
-  using ParamKey = std::tuple<ir::Type, Dim, DataLayout, LoD, size_t>;
+  using ParamKey = std::tuple<pir::Type, Dim, DataLayout, LoD, size_t>;
 
-  DenseTensorTypeStorage(const ir::Type& dtype,
+  DenseTensorTypeStorage(const pir::Type& dtype,
                          const Dim& dims,
                          const DataLayout& layout,
                          const LoD& lod,
@@ -85,22 +85,22 @@ struct DenseTensorTypeStorage : public pir::TypeStorage {
     std::size_t hash_value = 0;
     // hash dtype
     hash_value =
-        ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
+        pir::hash_combine(hash_value, std::hash<pir::Type>()(std::get<0>(key)));
     // hash dims
     hash_value =
-        ir::hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
+        pir::hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
     // hash layout
-    hash_value = ir::hash_combine(
+    hash_value = pir::hash_combine(
         hash_value,
         std::hash<std::underlying_type<DataLayout>::type>()(
             static_cast<std::underlying_type<DataLayout>::type>(
                 std::get<2>(key))));
     // hash lod
     hash_value =
-        ir::hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
+        pir::hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
     // hash offset
     hash_value =
-        ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
+        pir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
     return hash_value;
   }
 
@@ -119,7 +119,7 @@ struct DenseTensorTypeStorage : public pir::TypeStorage {
   /// \brief DenseTensorTypeStorage include five parameters: dims, dtype,
   /// layout, lod, offset.
   ///
-  ir::Type dtype_;
+  pir::Type dtype_;
   Dim dims_;
   DataLayout layout_;
   LoD lod_;
@@ -183,4 +183,4 @@ struct VectorTypeStorage : public TypeStorage {
   size_t size_;
 };
 
-}  // namespace ir
+}  // namespace pir
diff --git a/paddle/ir/core/cast_utils.h b/paddle/pir/core/cast_utils.h
similarity index 84%
rename from paddle/ir/core/cast_utils.h
rename to paddle/pir/core/cast_utils.h
index dcc4b89fe8b04..db9f864aaabc3 100644
--- a/paddle/ir/core/cast_utils.h
+++ b/paddle/pir/core/cast_utils.h
@@ -14,9 +14,10 @@
 
 #pragma once
 
+#include <memory>
 #include <type_traits>
 
-namespace ir {
+namespace pir {
 ///
 /// \brief The template function actually called by isa_wrap.
 ///
@@ -114,7 +115,7 @@ struct ReturnTypeDuduction {
 ///
 /// cast From to To
 ///
-template <typename To, typename From>
+template <typename To, typename From, typename Enabler = void>
 struct cast_impl {
   // This _is_ a simple type, just cast it.
   static typename ReturnTypeDuduction<To, From>::type call(const From &Val) {
@@ -125,7 +126,15 @@ struct cast_impl {
 };
 
 template <typename To, typename From>
-inline typename ReturnTypeDuduction<To, From>::type cast(From &Val) {  // NOLINT
+inline decltype(auto) cast(const From &Val) {
+  if (!isa<To>(Val)) {
+    throw("cast() argument of incompatible type!");
+  }
+  return cast_impl<To, const From>::call(Val);
+}
+
+template <typename To, typename From>
+inline decltype(auto) cast(From &Val) {  // NOLINT
   if (!isa<To>(Val)) {
     throw("cast() argument of incompatible type!");
   }
@@ -133,25 +142,32 @@ inline decltype(auto) cast(From &Val) {  // NOLINT
 }
 
 template <typename To, typename From>
-inline typename ReturnTypeDuduction<To, From>::type cast(From *Val) {
+inline decltype(auto) cast(From *Val) {
   if (!isa<To>(Val)) {
     throw("cast() argument of incompatible type!");
   }
   return cast_impl<To, From *>::call(Val);
 }
 
+template <typename To, typename From>
+inline decltype(auto) cast(std::unique_ptr<From> &&Val) {
+  if (!isa<To>(Val)) {
+    throw("cast() argument of incompatible type!");
+  }
+  return cast_impl<To, std::unique_ptr<From>>::call(std::move(Val));
+}
+
 ///
 /// \brief dyn_cast From to To.
 ///
 template <typename To, typename From>
-inline std::decay_t<typename ReturnTypeDuduction<To, From>::type> dyn_cast(
-    From &Val) {  // NOLINT
+inline decltype(auto) dyn_cast(From &Val) {  // NOLINT
   return isa<To>(Val) ? cast<To>(Val) : nullptr;
 }
 
 template <typename To, typename From>
-inline typename ReturnTypeDuduction<To, From>::type dyn_cast(From *Val) {
+inline decltype(auto) dyn_cast(From *Val) {
   return isa<To>(Val) ? cast<To>(Val) : nullptr;
 }
-}  // namespace ir
+}  // namespace pir
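A small sketch of the cast helpers after this change; the new overloads mean cast also accepts const references and unique_ptr rvalues:

    pir::Type ty = pir::Float32Type::get(pir::IrContext::Instance());
    if (pir::isa<pir::Float32Type>(ty)) {
      auto f32 = pir::cast<pir::Float32Type>(ty);  // throws on mismatch
    }
    auto i32 = pir::dyn_cast<pir::Int32Type>(ty);  // null handle on mismatch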
diff --git a/paddle/ir/core/dialect.cc b/paddle/pir/core/dialect.cc
similarity index 88%
rename from paddle/ir/core/dialect.cc
rename to paddle/pir/core/dialect.cc
index 0a4a6cc3b3854..e6831e977fa31 100644
--- a/paddle/ir/core/dialect.cc
+++ b/paddle/pir/core/dialect.cc
@@ -12,10 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/ir/core/dialect.h"
+#include "paddle/pir/core/dialect.h"
 
-namespace ir {
-Dialect::Dialect(std::string name, ir::IrContext *context, ir::TypeId id)
+namespace pir {
+Dialect::Dialect(std::string name, pir::IrContext *context, pir::TypeId id)
     : name_(std::move(name)), context_(context), id_(id) {}
 
 Dialect::~Dialect() = default;
@@ -32,4 +32,4 @@ IrContext *DialectInterface::ir_context() const {
   return dialect_->ir_context();
 }
 
-}  // namespace ir
+}  // namespace pir
diff --git a/paddle/ir/core/dialect.h b/paddle/pir/core/dialect.h
similarity index 94%
rename from paddle/ir/core/dialect.h
rename to paddle/pir/core/dialect.h
index f07a4242f362c..07debaf196041 100644
--- a/paddle/ir/core/dialect.h
+++ b/paddle/pir/core/dialect.h
@@ -17,15 +17,15 @@
 #include <memory>
 #include <unordered_map>
 
-#include "paddle/ir/core/attribute.h"
-#include "paddle/ir/core/attribute_base.h"
-#include "paddle/ir/core/dialect_interface.h"
-#include "paddle/ir/core/enforce.h"
-#include "paddle/ir/core/ir_context.h"
-#include "paddle/ir/core/op_base.h"
-#include "paddle/ir/core/type_base.h"
+#include "paddle/pir/core/attribute.h"
+#include "paddle/pir/core/attribute_base.h"
+#include "paddle/pir/core/dialect_interface.h"
+#include "paddle/pir/core/enforce.h"
+#include "paddle/pir/core/ir_context.h"
+#include "paddle/pir/core/op_base.h"
+#include "paddle/pir/core/type_base.h"
 
-namespace ir {
+namespace pir {
 
 class Operation;
 class IrPrinter;
@@ -174,4 +174,4 @@ class IR_API Dialect {
   std::unordered_map<TypeId, std::unique_ptr<DialectInterface>>
       registered_interfaces_;
 };
-}  // namespace ir
+}  // namespace pir
diff --git a/paddle/ir/core/dialect_interface.h b/paddle/pir/core/dialect_interface.h
similarity index 96%
rename from paddle/ir/core/dialect_interface.h
rename to paddle/pir/core/dialect_interface.h
index e24b3481f4ef4..7cb2b89de03eb 100644
--- a/paddle/ir/core/dialect_interface.h
+++ b/paddle/pir/core/dialect_interface.h
@@ -14,9 +14,9 @@
 
 #pragma once
 
-#include "paddle/ir/core/type_id.h"
+#include "paddle/pir/core/type_id.h"
 
-namespace ir {
+namespace pir {
 class Dialect;
 class IrContext;
 ///
@@ -64,4 +64,4 @@ class IR_API DialectInterface {
   TypeId interface_id_;
 };
 
-}  // namespace ir
+}  // namespace pir
diff --git a/paddle/ir/core/dll_decl.h b/paddle/pir/core/dll_decl.h
similarity index 100%
rename from paddle/ir/core/dll_decl.h
rename to paddle/pir/core/dll_decl.h
diff --git a/paddle/ir/core/enforce.h b/paddle/pir/core/enforce.h
similarity index 95%
rename from paddle/ir/core/enforce.h
rename to paddle/pir/core/enforce.h
index 10735297f305d..a3b1401b64d25 100644
--- a/paddle/ir/core/enforce.h
+++ b/paddle/pir/core/enforce.h
@@ -30,7 +30,7 @@ inline bool is_error(const T& stat) {
   return !stat;
 }
 
-namespace ir {
+namespace pir {
 class IrNotMetException : public std::exception {
  public:
   explicit IrNotMetException(const std::string& str) : err_str_(str) {}
@@ -44,7 +44,7 @@ class IrNotMetException : public std::exception {
 #define IR_THROW(...)                                                   \
   do {                                                                  \
     try {                                                               \
-      throw ir::IrNotMetException(                                      \
+      throw pir::IrNotMetException(                                     \
           paddle::string::Sprintf("Error occured at: %s:%d :\n%s",      \
                                   __FILE__,                             \
                                   __LINE__,                             \
@@ -60,7 +60,7 @@ class IrNotMetException : public std::exception {
     bool __cond__(COND);                                                \
     if (UNLIKELY(is_error(__cond__))) {                                 \
       try {                                                             \
-        throw ir::IrNotMetException(                                    \
+        throw pir::IrNotMetException(                                   \
             paddle::string::Sprintf("Error occured at: %s:%d :\n%s",    \
                                     __FILE__,                           \
                                     __LINE__,                           \
@@ -72,4 +72,4 @@ class IrNotMetException : public std::exception {
     }                                                                   \
   } while (0)
 
-}  // namespace ir
+}  // namespace pir
diff --git a/paddle/ir/core/op_base.cc b/paddle/pir/core/interface_support.cc
similarity index 72%
rename from paddle/ir/core/op_base.cc
rename to paddle/pir/core/interface_support.cc
index 6f6dca0cdc125..19cba9de0bd85 100644
--- a/paddle/ir/core/op_base.cc
+++ b/paddle/pir/core/interface_support.cc
@@ -12,20 +12,22 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/ir/core/op_base.h"
-namespace ir {
-InterfaceValue::~InterfaceValue() {
+#include "paddle/pir/core/interface_support.h"
+
+namespace pir {
+details::InterfaceValue::~InterfaceValue() {
   if (model_) free(model_);
 }
 
-InterfaceValue::InterfaceValue(InterfaceValue&& val) noexcept {
+details::InterfaceValue::InterfaceValue(InterfaceValue&& val) noexcept {
   type_id_ = val.type_id_;
   model_ = val.model_;
   val.model_ = nullptr;
 }
 
-InterfaceValue& InterfaceValue::operator=(InterfaceValue&& val) noexcept {
+details::InterfaceValue& details::InterfaceValue::operator=(
+    InterfaceValue&& val) noexcept {
   swap(std::move(val));
   return *this;
 }
-}  // namespace ir
+}  // namespace pir
diff --git a/paddle/pir/core/interface_support.h b/paddle/pir/core/interface_support.h
new file mode 100644
index 0000000000000..df8f776d7b87b
--- /dev/null
+++ b/paddle/pir/core/interface_support.h
@@ -0,0 +1,122 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pir/core/enforce.h"
+#include "paddle/pir/core/interface_value.h"
+
+namespace pir {
+namespace details {
+template <typename ConcreteT, typename... Args>
+class ConstructInterfacesOrTraits {
+ public:
+  /// Construct method for interfaces.
+  static details::InterfaceValue *interface(
+      details::InterfaceValue *p_interface) {
+    (void)std::initializer_list<int>{
+        0, (PlacementConstrctInterface<Args>(p_interface), 0)...};
+    return p_interface;
+  }
+
+  /// Construct method for traits.
+  static TypeId *trait(TypeId *p_trait) {
+    (void)std::initializer_list<int>{
+        0, (PlacementConstrctTrait<Args>(p_trait), 0)...};
+    return p_trait;
+  }
+
+ private:
+  /// Placement new interface.
+  template <typename T>
+  static void PlacementConstrctInterface(
+      details::InterfaceValue *&p_interface) {  // NOLINT
+    p_interface->swap(details::InterfaceValue::get<ConcreteT, T>());
+    VLOG(6) << "New a interface: id["
+            << (p_interface->type_id()).AsOpaquePointer() << "].";
+    ++p_interface;
+  }
+
+  /// Placement new trait.
+  template <typename T>
+  static void PlacementConstrctTrait(pir::TypeId *&p_trait) {  // NOLINT
+    *p_trait = TypeId::get<T>();
+    VLOG(6) << "New a trait: id[" << p_trait->AsOpaquePointer() << "].";
+    ++p_trait;
+  }
+};
+
+/// Specialized for tuple type.
+template <typename ConcreteT, typename... Args>
+class ConstructInterfacesOrTraits<ConcreteT, std::tuple<Args...>> {
+ public:
+  /// Construct method for interfaces.
+  static details::InterfaceValue *interface(
+      details::InterfaceValue *p_interface) {
+    return ConstructInterfacesOrTraits<ConcreteT, Args...>::interface(
+        p_interface);
+  }
+
+  /// Construct method for traits.
+  static TypeId *trait(TypeId *p_trait) {
+    return ConstructInterfacesOrTraits<ConcreteT, Args...>::trait(p_trait);
+  }
+};
+
+template <typename T>
+void *LookUp(const TypeId &interface_id,
+             const uint32_t num_interfaces,
+             const uint32_t num_traits,
+             const T *t) {
+  if (num_interfaces > 0) {
+    const details::InterfaceValue *p_first_interface =
+        reinterpret_cast<const details::InterfaceValue *>(
+            reinterpret_cast<const char *>(t) - sizeof(TypeId) * num_traits -
+            sizeof(details::InterfaceValue) * num_interfaces);
+    size_t left = 0, right = num_interfaces;
+    while (left < right) {
+      size_t mid = (left + right) / 2;
+      if ((p_first_interface + mid)->type_id() == interface_id) {
+        return (p_first_interface + mid)->model();
+      } else if ((p_first_interface + mid)->type_id() < interface_id) {
+        left = mid + 1;
+      } else {
+        right = mid;
+      }
+    }
+  }
+  return nullptr;
+}
+
+template <typename T>
+std::vector<details::InterfaceValue> GetInterfaceMap() {
+  constexpr size_t interfaces_num =
+      std::tuple_size<typename T::InterfaceList>::value;
+  std::vector<details::InterfaceValue> interfaces_map(interfaces_num);
+  ConstructInterfacesOrTraits<T, typename T::InterfaceList>::interface(
+      interfaces_map.data());
+  return interfaces_map;
+}
+
+template <typename T>
+std::vector<TypeId> GetTraitSet() {
+  constexpr size_t traits_num = std::tuple_size<typename T::TraitList>::value;
+  std::vector<TypeId> trait_set(traits_num);
+  auto p_first_trait = trait_set.data();
+  ConstructInterfacesOrTraits<T, typename T::TraitList>::trait(p_first_trait);
+  return trait_set;
+}
+
+}  // namespace details
+
+}  // namespace pir
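LookUp above only works because of a layout contract: an op's interface table and trait ids are allocated immediately in front of the object it inspects, with interfaces kept sorted by TypeId. A sketch of that assumed layout:

    // Assumed layout behind LookUp (interfaces sorted by type_id()):
    //
    //   [InterfaceValue_0 .. InterfaceValue_{n-1}][TypeId_0 .. TypeId_{m-1}][T]
    //                                                                        ^t
    //
    // Subtracting sizeof(InterfaceValue) * n and sizeof(TypeId) * m from `t`
    // recovers the first interface entry, and the sorted order is what lets
    // the while-loop bisect instead of scanning linearly.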
diff --git a/paddle/pir/core/interface_value.h b/paddle/pir/core/interface_value.h
new file mode 100644
index 0000000000000..fe7bc6d9ca2a8
--- /dev/null
+++ b/paddle/pir/core/interface_value.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pir/core/type_id.h"
+#include "paddle/pir/core/utils.h"
+
+namespace pir {
+
+namespace details {
+class IR_API InterfaceValue {
+ public:
+  template <typename ConcreteT, typename T>
+  static InterfaceValue get() {
+    InterfaceValue val;
+    val.type_id_ = TypeId::get<T>();
+    val.model_ = malloc(sizeof(typename T::template Model<ConcreteT>));
+    if (val.model_ == nullptr) {
+      throw("Alloc memory for interface failed.");
+    }
+    static_assert(std::is_trivially_destructible<
+                      typename T::template Model<ConcreteT>>::value,
+                  "interface models must be trivially destructible");
+    new (val.model_) typename T::template Model<ConcreteT>();
+    return val;
+  }
+  TypeId type_id() const { return type_id_; }
+  void *model() const { return model_; }
+
+  InterfaceValue() = default;
+  explicit InterfaceValue(TypeId type_id) : type_id_(type_id) {}
+  InterfaceValue(const InterfaceValue &) = delete;
+  InterfaceValue(InterfaceValue &&) noexcept;
+  InterfaceValue &operator=(const InterfaceValue &) = delete;
+  InterfaceValue &operator=(InterfaceValue &&) noexcept;
+  ~InterfaceValue();
+  void swap(InterfaceValue &&val) {
+    using std::swap;
+    swap(type_id_, val.type_id_);
+    swap(model_, val.model_);
+  }
+
+  ///
+  /// \brief Comparison operations.
+  ///
+  inline bool operator<(const InterfaceValue &other) const {
+    return type_id_ < other.type_id_;
+  }
+
+ private:
+  TypeId type_id_;
+  void *model_{nullptr};
+};
+
+}  // namespace details
+}  // namespace pir
diff --git a/paddle/ir/core/ir_context.cc b/paddle/pir/core/ir_context.cc
similarity index 82%
rename from paddle/ir/core/ir_context.cc
rename to paddle/pir/core/ir_context.cc
index 9fe79ac84b6a4..bfc05fabcf35b 100644
--- a/paddle/ir/core/ir_context.cc
+++ b/paddle/pir/core/ir_context.cc
@@ -12,19 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/ir/core/ir_context.h"
+#include "paddle/pir/core/ir_context.h"
 
 #include <unordered_map>
 
-#include "paddle/ir/core/attribute_base.h"
-#include "paddle/ir/core/builtin_dialect.h"
-#include "paddle/ir/core/builtin_type.h"
-#include "paddle/ir/core/dialect.h"
-#include "paddle/ir/core/op_info_impl.h"
-#include "paddle/ir/core/spin_lock.h"
-#include "paddle/ir/core/type_base.h"
+#include "paddle/pir/core/attribute_base.h"
+#include "paddle/pir/core/builtin_dialect.h"
+#include "paddle/pir/core/builtin_type.h"
+#include "paddle/pir/core/dialect.h"
+#include "paddle/pir/core/op_info_impl.h"
+#include "paddle/pir/core/spin_lock.h"
+#include "paddle/pir/core/type_base.h"
 
-namespace ir {
+namespace pir {
 // The implementation class of the IrContext class, cache registered
 // AbstractType, TypeStorage, AbstractAttribute, AttributeStorage, Dialect.
 class IrContextImpl {
@@ -32,7 +32,7 @@ class IrContextImpl {
   IrContextImpl() = default;
 
   ~IrContextImpl() {
-    std::lock_guard<ir::SpinLock> guard(destructor_lock_);
+    std::lock_guard<pir::SpinLock> guard(destructor_lock_);
     for (auto &abstract_type_map : registed_abstract_types_) {
       delete abstract_type_map.second;
     }
@@ -54,48 +54,48 @@ class IrContextImpl {
     registed_op_infos_.clear();
   }
 
-  void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
-    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
+  void RegisterAbstractType(pir::TypeId type_id, AbstractType *abstract_type) {
+    std::lock_guard<pir::SpinLock> guard(registed_abstract_types_lock_);
     VLOG(6) << "Register an abstract_type of: [TypeId_hash="
-            << std::hash<ir::TypeId>()(type_id)
+            << std::hash<pir::TypeId>()(type_id)
             << ", AbstractType_ptr=" << abstract_type << "].";
     registed_abstract_types_.emplace(type_id, abstract_type);
   }
 
-  AbstractType *GetAbstractType(ir::TypeId type_id) {
-    std::lock_guard<ir::SpinLock> guard(registed_abstract_types_lock_);
+  AbstractType *GetAbstractType(pir::TypeId type_id) {
+    std::lock_guard<pir::SpinLock> guard(registed_abstract_types_lock_);
     auto iter = registed_abstract_types_.find(type_id);
     if (iter != registed_abstract_types_.end()) {
       VLOG(6) << "Found a cached abstract_type of: [TypeId_hash="
-              << std::hash<ir::TypeId>()(type_id)
+              << std::hash<pir::TypeId>()(type_id)
              << ", AbstractType_ptr=" << iter->second << "].";
       return iter->second;
     }
     LOG(WARNING) << "No cache found abstract_type of: [TypeId_hash="
-                 << std::hash<ir::TypeId>()(type_id) << "].";
+                 << std::hash<pir::TypeId>()(type_id) << "].";
     return nullptr;
   }
 
-  void RegisterAbstractAttribute(ir::TypeId type_id,
+  void RegisterAbstractAttribute(pir::TypeId type_id,
                                  AbstractAttribute *abstract_attribute) {
-    std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
+    std::lock_guard<pir::SpinLock> guard(registed_abstract_attributes_lock_);
     VLOG(6) << "Register an abstract_attribute of: [TypeId_hash="
-            << std::hash<ir::TypeId>()(type_id)
+            << std::hash<pir::TypeId>()(type_id)
             << ", AbstractAttribute_ptr=" << abstract_attribute << "].";
     registed_abstract_attributes_.emplace(type_id, abstract_attribute);
   }
 
-  AbstractAttribute *GetAbstractAttribute(ir::TypeId type_id) {
-    std::lock_guard<ir::SpinLock> guard(registed_abstract_attributes_lock_);
+  AbstractAttribute *GetAbstractAttribute(pir::TypeId type_id) {
+    std::lock_guard<pir::SpinLock> guard(registed_abstract_attributes_lock_);
     auto iter = registed_abstract_attributes_.find(type_id);
     if (iter != registed_abstract_attributes_.end()) {
       VLOG(4) << "Found a cached abstract_attribute of: [TypeId_hash="
-              << std::hash<ir::TypeId>()(type_id)
+              << std::hash<pir::TypeId>()(type_id)
              << ", AbstractAttribute_ptr=" << iter->second << "].";
       return iter->second;
     }
     LOG(WARNING) << "No cache found abstract_attribute of: [TypeId_hash="
-                 << std::hash<ir::TypeId>()(type_id) << "].";
+                 << std::hash<pir::TypeId>()(type_id) << "].";
     return nullptr;
   }
 
@@ -104,14 +104,14 @@ class IrContextImpl {
   }
 
   void RegisterOpInfo(const std::string &name, OpInfo info) {
-    std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
+    std::lock_guard<pir::SpinLock> guard(registed_op_infos_lock_);
     VLOG(6) << "Register an operation of: [Name=" << name
             << ", OpInfo ptr=" << info.AsOpaquePointer() << "].";
     registed_op_infos_.emplace(name, info);
   }
 
   OpInfo GetOpInfo(const std::string &name) {
-    std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
+    std::lock_guard<pir::SpinLock> guard(registed_op_infos_lock_);
     auto iter = registed_op_infos_.find(name);
     if (iter != registed_op_infos_.end()) {
       VLOG(8) << "Found a cached OpInfo of: [name=" << name
@@ -124,7 +124,7 @@ class IrContextImpl {
   const OpInfoMap &registered_op_info_map() { return registed_op_infos_; }
 
   void RegisterDialect(std::string name, Dialect *dialect) {
std::lock_guard guard(registed_dialect_lock_); + std::lock_guard guard(registed_dialect_lock_); VLOG(6) << "Register a dialect of: [name=" << name << ", dialect_ptr=" << dialect << "]."; registed_dialect_.emplace(name, dialect); @@ -135,7 +135,7 @@ class IrContextImpl { } Dialect *GetDialect(const std::string &name) { - std::lock_guard guard(registed_dialect_lock_); + std::lock_guard guard(registed_dialect_lock_); auto iter = registed_dialect_.find(name); if (iter != registed_dialect_.end()) { VLOG(6) << "Found a cached dialect of: [name=" << name @@ -148,7 +148,7 @@ class IrContextImpl { // Cached AbstractType instances. std::unordered_map registed_abstract_types_; - ir::SpinLock registed_abstract_types_lock_; + pir::SpinLock registed_abstract_types_lock_; // TypeStorage uniquer and cache instances. StorageManager registed_type_storage_manager_; // Cache some built-in type objects. @@ -168,19 +168,19 @@ class IrContextImpl { // Cached AbstractAttribute instances. std::unordered_map registed_abstract_attributes_; - ir::SpinLock registed_abstract_attributes_lock_; + pir::SpinLock registed_abstract_attributes_lock_; // AttributeStorage uniquer and cache instances. StorageManager registed_attribute_storage_manager_; // The dialect registered in the context. std::unordered_map registed_dialect_; - ir::SpinLock registed_dialect_lock_; + pir::SpinLock registed_dialect_lock_; // The Op registered in the context. OpInfoMap registed_op_infos_; - ir::SpinLock registed_op_infos_lock_; + pir::SpinLock registed_op_infos_lock_; - ir::SpinLock destructor_lock_; + pir::SpinLock destructor_lock_; }; IrContext *IrContext::Instance() { @@ -223,7 +223,7 @@ AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) { } void IrContext::RegisterAbstractAttribute( - ir::TypeId type_id, AbstractAttribute &&abstract_attribute) { + pir::TypeId type_id, AbstractAttribute &&abstract_attribute) { if (GetRegisteredAbstractAttribute(type_id) == nullptr) { impl().RegisterAbstractAttribute( type_id, new AbstractAttribute(std::move(abstract_attribute))); @@ -274,7 +274,7 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) { return nullptr; } -void IrContext::RegisterAbstractType(ir::TypeId type_id, +void IrContext::RegisterAbstractType(pir::TypeId type_id, AbstractType &&abstract_type) { if (GetRegisteredAbstractType(type_id) == nullptr) { impl().RegisterAbstractType(type_id, @@ -284,14 +284,15 @@ void IrContext::RegisterAbstractType(ir::TypeId type_id, } } -void IrContext::RegisterOpInfo(Dialect *dialect, - TypeId op_id, - const char *name, - std::vector &&interface_map, - const std::vector &trait_set, - size_t attributes_num, - const char **attributes_name, - VerifyPtr verify) { +void IrContext::RegisterOpInfo( + Dialect *dialect, + TypeId op_id, + const char *name, + std::vector &&interface_map, + const std::vector &trait_set, + size_t attributes_num, + const char **attributes_name, + VerifyPtr verify) { if (impl().IsOpInfoRegistered(name)) { LOG(WARNING) << name << " op already registered."; } else { @@ -361,4 +362,4 @@ Complex128Type Complex128Type::get(IrContext *ctx) { return ctx->impl().complex128_type; } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/ir_context.h b/paddle/pir/core/ir_context.h similarity index 95% rename from paddle/ir/core/ir_context.h rename to paddle/pir/core/ir_context.h index ebec8d202ceb5..a68c87f3bee0b 100644 --- a/paddle/ir/core/ir_context.h +++ b/paddle/pir/core/ir_context.h @@ -18,9 +18,9 @@ #include #include -#include 
"paddle/ir/core/dll_decl.h" +#include "paddle/pir/core/dll_decl.h" -namespace ir { +namespace pir { class IrContextImpl; class StorageManager; class AbstractType; @@ -28,12 +28,13 @@ class AbstractAttribute; class TypeId; class Dialect; class OpInfo; -class InterfaceValue; class Type; class OpResult; class Attribute; class Operation; - +namespace details { +class InterfaceValue; +} using OpInfoMap = std::unordered_map; /// @@ -86,7 +87,7 @@ class IR_API IrContext { /// \param type_id The type id of the AbstractAttribute. /// \param abstract_attribute AbstractAttribute provided by user. /// - void RegisterAbstractAttribute(ir::TypeId type_id, + void RegisterAbstractAttribute(pir::TypeId type_id, AbstractAttribute &&abstract_attribute); /// @@ -109,7 +110,7 @@ class IR_API IrContext { void RegisterOpInfo(Dialect *dialect, TypeId op_id, const char *name, - std::vector &&interface_map, + std::vector &&interface_map, const std::vector &trait_set, size_t attributes_num, const char **attributes_name, @@ -190,4 +191,4 @@ class IR_API IrContext { IrContextImpl *impl_; }; -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/ir_printer.cc b/paddle/pir/core/ir_printer.cc similarity index 95% rename from paddle/ir/core/ir_printer.cc rename to paddle/pir/core/ir_printer.cc index 0d0ce64f679de..7fa8e076ad147 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/pir/core/ir_printer.cc @@ -17,17 +17,17 @@ #include #include -#include "paddle/ir/core/block.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/builtin_type.h" -#include "paddle/ir/core/dialect.h" -#include "paddle/ir/core/ir_printer.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/utils.h" -#include "paddle/ir/core/value.h" - -namespace ir { +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/ir_printer.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/utils.h" +#include "paddle/pir/core/value.h" + +namespace pir { namespace { constexpr char newline[] = "\n"; // NOLINT @@ -334,4 +334,4 @@ std::ostream& operator<<(std::ostream& os, const Program& prog) { return os; } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/ir_printer.h b/paddle/pir/core/ir_printer.h similarity index 86% rename from paddle/ir/core/ir_printer.h rename to paddle/pir/core/ir_printer.h index c393d2dfbe90a..a845bec52490c 100644 --- a/paddle/ir/core/ir_printer.h +++ b/paddle/pir/core/ir_printer.h @@ -18,15 +18,15 @@ #include #include -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/block.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/region.h" -#include "paddle/ir/core/type.h" -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/region.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/value.h" -namespace ir { +namespace pir { class BasicIrPrinter { public: @@ -75,4 +75,4 @@ class IR_API IrPrinter : public BasicIrPrinter { std::unordered_map aliases_; }; -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/macros.h b/paddle/pir/core/macros.h similarity index 97% rename from paddle/ir/core/macros.h rename to 
paddle/pir/core/macros.h
index 962ca6d4107f3..25d6dd5a812ab 100644
--- a/paddle/ir/core/macros.h
+++ b/paddle/pir/core/macros.h
@@ -13,7 +13,7 @@
 // limitations under the License.
 #pragma once
-namespace ir {
+namespace pir {
 // TODO(Aurelius84): We also have DISABLE_COPY_AND_ASSIGN in phi/core/macros.h,
 // however it's not recommended to use it in ir namespace. So we define this again
 // here.
@@ -28,4 +28,4 @@ namespace ir {
   classname& operator=(classname&&) = delete
 #endif
-}  // namespace ir
+}  // namespace pir
diff --git a/paddle/pir/core/op_base.h b/paddle/pir/core/op_base.h
new file mode 100644
index 0000000000000..e51018d5c3f57
--- /dev/null
+++ b/paddle/pir/core/op_base.h
@@ -0,0 +1,151 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+
+#include "paddle/pir/core/enforce.h"
+#include "paddle/pir/core/interface_support.h"
+#include "paddle/pir/core/op_result.h"
+#include "paddle/pir/core/operation.h"
+#include "paddle/pir/core/utils.h"
+
+namespace pir {
+
+class IR_API OpBase {
+ public:
+  explicit OpBase(Operation *operation = nullptr) : operation_(operation) {}
+
+  Operation *operation() const {
+    IR_ENFORCE(operation_, "Can't use operation() in a null op.");
+    return operation_;
+  }
+
+  explicit operator bool() const { return operation_ != nullptr; }
+
+  operator Operation *() const { return operation(); }
+
+  Operation *operator->() const { return operation(); }
+
+  IrContext *ir_context() const { return operation()->ir_context(); }
+
+  uint32_t num_results() const { return operation()->num_results(); }
+
+  uint32_t num_operands() const { return operation()->num_operands(); }
+
+  const AttributeMap &attributes() const { return operation()->attributes(); }
+
+  Value operand_source(uint32_t index) const {
+    return operation()->operand_source(index);
+  }
+
+  OpResult result(uint32_t index) const { return operation()->result(index); }
+
+  pir::Attribute attribute(const std::string &name) {
+    return operation()->attribute(name);
+  }
+
+  template <typename T>
+  T attribute(const std::string &name) {
+    return operation()->attribute<T>(name);
+  }
+
+ private:
+  Operation *operation_;  // Not owned
+};
+
+///
+/// \brief OpTrait
+///
+template <class ConcreteTrait>
+class OpTraitBase : public OpBase {
+ public:
+  explicit OpTraitBase(Operation *op) : OpBase(op) {}
+
+  static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
+
+  static ConcreteTrait dyn_cast(Operation *op) {
+    if (op && op->HasTrait<ConcreteTrait>()) {
+      return ConcreteTrait(op);
+    }
+    return ConcreteTrait(nullptr);
+  }
+};
+
+///
+/// \brief OpInterface
+///
+template <typename ConcreteInterface>
+class OpInterfaceBase : public OpBase {
+ public:
+  explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
+
+  // Accessor for the ID of this interface.
+ static TypeId GetInterfaceId() { return TypeId::get(); } + + static ConcreteInterface dyn_cast(Operation *op) { + if (op && op->HasInterface()) { + return ConcreteInterface( + op, op->info().GetInterfaceImpl()); + } + return ConcreteInterface(nullptr, nullptr); + } +}; + +template +class Op : public OpBase { + public: + using OpBase::OpBase; + + using TraitList = + typename Filter>::Type; + + using InterfaceList = + typename Filter>::Type; + + static ConcreteOp dyn_cast(Operation *op) { + if (op && op->info().id() == TypeId::get()) { + return ConcreteOp(op); + } + return ConcreteOp(nullptr); + } + + static bool classof(const Operation *op) { + return op && op->info().id() == TypeId::get(); + } + + static std::vector GetInterfaceMap() { + return pir::details::GetInterfaceMap(); + } + + static std::vector GetTraitSet() { + return pir::details::GetTraitSet(); + } + + // Checking that the derived class does not define any member by comparing + // its size to an ad-hoc EmptyOp. + static constexpr bool HasNoDataMembers() { + class EmptyOp : public Op {}; + return sizeof(ConcreteOp) == sizeof(EmptyOp); + } + + // Implementation of `VerifyInvariantsFn` OperationName hook. + static void VerifyInvariants(Operation *op) { + static_assert(HasNoDataMembers(), + "Op class shouldn't define new data members"); + op->dyn_cast().Verify(); + } +}; + +} // namespace pir diff --git a/paddle/ir/core/op_info.cc b/paddle/pir/core/op_info.cc similarity index 87% rename from paddle/ir/core/op_info.cc rename to paddle/pir/core/op_info.cc index 6c9b62f56e63f..b018bec30448d 100644 --- a/paddle/ir/core/op_info.cc +++ b/paddle/pir/core/op_info.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ir/core/op_info.h" -#include "paddle/ir/core/dialect.h" -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/op_info_impl.h" +#include "paddle/pir/core/op_info.h" +#include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/op_info_impl.h" -namespace ir { +namespace pir { bool OpInfo::HasTrait(TypeId trait_id) const { return impl_ && impl_->HasTrait(trait_id); } @@ -40,4 +40,4 @@ void OpInfo::Verify(Operation *operation) const { impl_->verify()(operation); } void *OpInfo::GetInterfaceImpl(TypeId interface_id) const { return impl_ ? 
impl_->GetInterfaceImpl(interface_id) : nullptr; } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/op_info.h b/paddle/pir/core/op_info.h similarity index 70% rename from paddle/ir/core/op_info.h rename to paddle/pir/core/op_info.h index f92d37d4b33e0..130c05037d8ae 100644 --- a/paddle/ir/core/op_info.h +++ b/paddle/pir/core/op_info.h @@ -16,9 +16,9 @@ #include #include -#include "paddle/ir/core/type_id.h" +#include "paddle/pir/core/type_id.h" -namespace ir { +namespace pir { class OpInfoImpl; class IrContext; class OpResult; @@ -61,15 +61,15 @@ class IR_API OpInfo { bool HasTrait(TypeId trait_id) const; - template + template bool HasInterface() const { - return HasInterface(TypeId::get()); + return HasInterface(TypeId::get()); } bool HasInterface(TypeId interface_id) const; - template - typename Interface::Concept *GetInterfaceImpl() const; + template + typename InterfaceT::Concept *GetInterfaceImpl() const; void *AsOpaquePointer() const { return impl_; } static OpInfo RecoverFromOpaquePointer(void *pointer) { @@ -84,22 +84,28 @@ class IR_API OpInfo { void *GetInterfaceImpl(TypeId interface_id) const; private: - OpInfoImpl *impl_{nullptr}; // not owned + /// The internal implementation of the operation name. + /// Not owned. + OpInfoImpl *impl_{nullptr}; }; -template -typename Interface::Concept *OpInfo::GetInterfaceImpl() const { - void *model = GetInterfaceImpl(TypeId::get()); - return reinterpret_cast(model); +/// +/// \brief Returns an instance of the concept object for the given interface if +/// it was registered to this operation, null otherwise. +/// +template +typename InterfaceT::Concept *OpInfo::GetInterfaceImpl() const { + void *model = GetInterfaceImpl(TypeId::get()); + return reinterpret_cast(model); } -} // namespace ir +} // namespace pir namespace std { template <> -struct hash { - std::size_t operator()(const ir::OpInfo &obj) const { - return std::hash()(obj.impl_); +struct hash { + std::size_t operator()(const pir::OpInfo &obj) const { + return std::hash()(obj.impl_); } }; } // namespace std diff --git a/paddle/ir/core/op_info_impl.cc b/paddle/pir/core/op_info_impl.cc similarity index 69% rename from paddle/ir/core/op_info_impl.cc rename to paddle/pir/core/op_info_impl.cc index 90469f1731be9..fa91d3173389a 100644 --- a/paddle/ir/core/op_info_impl.cc +++ b/paddle/pir/core/op_info_impl.cc @@ -12,14 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
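A usage sketch for the templated accessors in op_info.h above, reusing the hypothetical `InferShapeInterface` from the earlier note; `Operation::info()` is the accessor that appears later in this patch in operation.h:

```cpp
// Query an operation's OpInfo for an interface and dispatch through the
// returned Concept pointer. HasInterface guards the call; calling
// GetInterfaceImpl alone would simply return nullptr for an
// unimplemented interface.
void TryInferShape(pir::Operation *op) {
  pir::OpInfo info = op->info();
  if (!info.HasInterface<InferShapeInterface>()) return;
  InferShapeInterface::Concept *impl =
      info.GetInterfaceImpl<InferShapeInterface>();
  impl->infer_shape(op);
}
```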
-#include "paddle/ir/core/op_info_impl.h" -#include "paddle/ir/core/dialect.h" +#include "paddle/pir/core/op_info_impl.h" +#include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/interface_support.h" -namespace ir { +namespace pir { OpInfo OpInfoImpl::Create(Dialect *dialect, TypeId op_id, const char *op_name, - std::vector &&interface_map, + std::vector &&interface_map, const std::vector &trait_set, size_t attributes_num, const char *attributes_name[], // NOLINT @@ -29,7 +30,7 @@ OpInfo OpInfoImpl::Create(Dialect *dialect, size_t traits_num = trait_set.size(); VLOG(6) << "Create OpInfoImpl with: " << interfaces_num << " interfaces, " << traits_num << " traits, " << attributes_num << " attributes."; - size_t base_size = sizeof(InterfaceValue) * interfaces_num + + size_t base_size = sizeof(details::InterfaceValue) * interfaces_num + sizeof(TypeId) * traits_num + sizeof(OpInfoImpl); char *base_ptr = static_cast(::operator new(base_size)); VLOG(6) << "Malloc " << base_size << " Bytes at " @@ -37,10 +38,10 @@ OpInfo OpInfoImpl::Create(Dialect *dialect, if (interfaces_num > 0) { std::sort(interface_map.begin(), interface_map.end()); for (size_t index = 0; index < interfaces_num; ++index) { - new (base_ptr + index * sizeof(InterfaceValue)) - InterfaceValue(std::move(interface_map[index])); + new (base_ptr + index * sizeof(details::InterfaceValue)) + details::InterfaceValue(std::move(interface_map[index])); } - base_ptr += interfaces_num * sizeof(InterfaceValue); + base_ptr += interfaces_num * sizeof(details::InterfaceValue); } if (traits_num > 0) { auto p_first_trait = reinterpret_cast(base_ptr); @@ -69,7 +70,7 @@ void OpInfoImpl::Destroy(OpInfo info) { } } -ir::IrContext *OpInfoImpl::ir_context() const { +pir::IrContext *OpInfoImpl::ir_context() const { return dialect_ ? 
dialect_->ir_context() : nullptr; } @@ -77,7 +78,7 @@ bool OpInfoImpl::HasTrait(TypeId trait_id) const { if (num_traits_ > 0) { const TypeId *p_first_trait = reinterpret_cast(reinterpret_cast(this) - - sizeof(ir::TypeId) * num_traits_); + sizeof(pir::TypeId) * num_traits_); return std::binary_search( p_first_trait, p_first_trait + num_traits_, trait_id); } @@ -86,49 +87,32 @@ bool OpInfoImpl::HasTrait(TypeId trait_id) const { bool OpInfoImpl::HasInterface(TypeId interface_id) const { if (num_interfaces_ > 0) { - const InterfaceValue *p_first_interface = - reinterpret_cast( + const details::InterfaceValue *p_first_interface = + reinterpret_cast( reinterpret_cast(this) - - sizeof(ir::TypeId) * num_traits_ - - sizeof(InterfaceValue) * num_interfaces_); + sizeof(pir::TypeId) * num_traits_ - + sizeof(details::InterfaceValue) * num_interfaces_); return std::binary_search(p_first_interface, p_first_interface + num_interfaces_, - InterfaceValue(interface_id)); + details::InterfaceValue(interface_id)); } return false; } void *OpInfoImpl::GetInterfaceImpl(TypeId interface_id) const { - if (num_interfaces_ > 0) { - const InterfaceValue *p_first_interface = - reinterpret_cast( - reinterpret_cast(this) - - sizeof(TypeId) * num_traits_ - - sizeof(InterfaceValue) * num_interfaces_); - size_t left = 0, right = num_interfaces_; - while (left < right) { - size_t mid = (left + right) / 2; - if ((p_first_interface + mid)->type_id() == interface_id) { - return (p_first_interface + mid)->model(); - } else if ((p_first_interface + mid)->type_id() < interface_id) { - left = mid + 1; - } else { - right = mid; - } - } - } - return nullptr; + return pir::details::LookUp( + interface_id, num_interfaces_, num_traits_, this); } void OpInfoImpl::Destroy() { VLOG(10) << "Destroy op_info impl at " << this; // (1) free interfaces char *base_ptr = reinterpret_cast(this) - - sizeof(ir::TypeId) * num_traits_ - - sizeof(InterfaceValue) * num_interfaces_; + sizeof(pir::TypeId) * num_traits_ - + sizeof(details::InterfaceValue) * num_interfaces_; if (num_interfaces_ > 0) { - InterfaceValue *p_interface_val = - reinterpret_cast(base_ptr); + details::InterfaceValue *p_interface_val = + reinterpret_cast(base_ptr); for (size_t i = 0; i < num_interfaces_; i++) { (p_interface_val + i)->~InterfaceValue(); } @@ -138,4 +122,4 @@ void OpInfoImpl::Destroy() { free(base_ptr); } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/op_info_impl.h b/paddle/pir/core/op_info_impl.h similarity index 91% rename from paddle/ir/core/op_info_impl.h rename to paddle/pir/core/op_info_impl.h index 52666f1b377c8..410c9ef371989 100644 --- a/paddle/ir/core/op_info_impl.h +++ b/paddle/pir/core/op_info_impl.h @@ -19,11 +19,11 @@ #include #include -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/op_base.h" -#include "paddle/ir/core/type.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/op_base.h" +#include "paddle/pir/core/type.h" -namespace ir { +namespace pir { class Dialect; /// @@ -38,7 +38,7 @@ class OpInfoImpl { static OpInfo Create(Dialect *dialect, TypeId op_id, const char *op_name, - std::vector &&interface_map, + std::vector &&interface_map, const std::vector &trait_set, size_t attributes_num, const char *attributes_name[], @@ -69,7 +69,7 @@ class OpInfoImpl { } private: - OpInfoImpl(ir::Dialect *dialect, + OpInfoImpl(pir::Dialect *dialect, TypeId op_id, const char *op_name, uint32_t num_interfaces, @@ -111,4 +111,4 @@ class OpInfoImpl { VerifyPtr verify_{nullptr}; }; -} // namespace 
ir +} // namespace pir diff --git a/paddle/pir/core/op_operand.cc b/paddle/pir/core/op_operand.cc new file mode 100644 index 0000000000000..b27f02ac23d4c --- /dev/null +++ b/paddle/pir/core/op_operand.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/pir/core/op_operand.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/op_operand_impl.h" + +#define CHECK_NULL_IMPL(class_name, func_name) \ + IR_ENFORCE(impl_, \ + "impl_ pointer is null when call func:" #func_name \ + " , in class: " #class_name ".") + +#define CHECK_OPOPEREND_NULL_IMPL(func_name) \ + CHECK_NULL_IMPL(OpOpernad, func_name) + +namespace pir { + +OpOperand::OpOperand(const detail::OpOperandImpl *impl) + : impl_(const_cast(impl)) {} + +OpOperand &OpOperand::operator=(const OpOperand &rhs) { + impl_ = rhs.impl_; + return *this; +} + +OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) { + if (this->impl_ == impl) return *this; + impl_ = const_cast(impl); + return *this; +} +OpOperand::operator bool() const { return impl_ && impl_->source(); } + +OpOperand OpOperand::next_use() const { + CHECK_OPOPEREND_NULL_IMPL(next_use); + return impl_->next_use(); +} + +Value OpOperand::source() const { + CHECK_OPOPEREND_NULL_IMPL(source); + return impl_->source(); +} + +Type OpOperand::type() const { return source().type(); } + +void OpOperand::set_source(Value value) { + CHECK_OPOPEREND_NULL_IMPL(set_source); + impl_->set_source(value); +} + +Operation *OpOperand::owner() const { + CHECK_OPOPEREND_NULL_IMPL(owner); + return impl_->owner(); +} + +void OpOperand::RemoveFromUdChain() { + CHECK_OPOPEREND_NULL_IMPL(RemoveFromUdChain); + return impl_->RemoveFromUdChain(); +} + +} // namespace pir diff --git a/paddle/pir/core/op_operand.h b/paddle/pir/core/op_operand.h new file mode 100644 index 0000000000000..96b355b861ffa --- /dev/null +++ b/paddle/pir/core/op_operand.h @@ -0,0 +1,69 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/dll_decl.h" + +namespace pir { +class Operation; +class Value; +class Type; + +namespace detail { +class OpOperandImpl; +} // namespace detail + +/// +/// \brief OpOperand class represents the op_operand of operation. This class +/// only provides interfaces, for specific implementation, see Impl class. 
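`OpOperand` above is a value-semantic handle over `detail::OpOperandImpl`; the impl (next file) threads every operand of a value into a singly linked use list through `next_use_`. A sketch of walking that list, assuming `pir::Value` exposes a `first_use()` accessor returning the head of the chain (that accessor is not part of this hunk):

```cpp
// Count the users of a value by following the use-def chain that
// OpOperandImpl maintains via InsertToUdChain/RemoveFromUdChain. The
// loop terminates when next_use() yields a handle with a null impl,
// which converts to false through OpOperand::operator bool.
size_t CountUses(pir::Value value) {
  size_t n = 0;
  for (pir::OpOperand use = value.first_use(); use; use = use.next_use()) {
    ++n;  // use.owner() would yield the consuming Operation here
  }
  return n;
}
```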
+/// +class IR_API OpOperand { + public: + OpOperand() = default; + + OpOperand(const OpOperand &other) = default; + + OpOperand(const detail::OpOperandImpl *impl); // NOLINT + + OpOperand &operator=(const OpOperand &rhs); + + OpOperand &operator=(const detail::OpOperandImpl *impl); + + bool operator==(const OpOperand &other) const { return impl_ == other.impl_; } + + bool operator!=(const OpOperand &other) const { return !operator==(other); } + + bool operator!() const { return impl_ == nullptr; } + + operator bool() const; + + OpOperand next_use() const; + + Value source() const; + + Type type() const; + + void set_source(Value value); + + Operation *owner() const; + + void RemoveFromUdChain(); + + friend Operation; + + private: + detail::OpOperandImpl *impl_{nullptr}; +}; +} // namespace pir diff --git a/paddle/pir/core/op_operand_impl.cc b/paddle/pir/core/op_operand_impl.cc new file mode 100644 index 0000000000000..44a3a5f28bb6e --- /dev/null +++ b/paddle/pir/core/op_operand_impl.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/pir/core/op_operand_impl.h" +#include "paddle/pir/core/value_impl.h" + +namespace pir { +namespace detail { + +pir::Operation *OpOperandImpl::owner() const { return owner_; } + +pir::detail::OpOperandImpl *OpOperandImpl::next_use() { return next_use_; } + +pir::Value OpOperandImpl::source() const { return source_; } + +void OpOperandImpl::set_source(Value source) { + RemoveFromUdChain(); + if (!source) { + return; + } + source_ = source; + InsertToUdChain(); +} + +OpOperandImpl::OpOperandImpl(pir::Value source, pir::Operation *owner) + : source_(source), owner_(owner) { + if (!source) { + return; + } + InsertToUdChain(); +} + +void OpOperandImpl::InsertToUdChain() { + prev_use_addr_ = source_.impl()->first_use_addr(); + next_use_ = source_.impl()->first_use(); + if (next_use_) { + next_use_->prev_use_addr_ = &next_use_; + } + source_.impl()->set_first_use(this); +} + +void OpOperandImpl::RemoveFromUdChain() { + if (!source_) return; + if (!prev_use_addr_) return; + if (prev_use_addr_ == source_.impl()->first_use_addr()) { + /// NOTE: In ValueImpl, first_use_offseted_by_index_ use lower three bits + /// storage index information, so need to be updated using the set_first_use + /// method here. + source_.impl()->set_first_use(next_use_); + } else { + *prev_use_addr_ = next_use_; + } + if (next_use_) { + next_use_->prev_use_addr_ = prev_use_addr_; + } + next_use_ = nullptr; + prev_use_addr_ = nullptr; + source_ = nullptr; +} + +OpOperandImpl::~OpOperandImpl() { RemoveFromUdChain(); } + +} // namespace detail +} // namespace pir diff --git a/paddle/pir/core/op_operand_impl.h b/paddle/pir/core/op_operand_impl.h new file mode 100644 index 0000000000000..f1bc9d23c0928 --- /dev/null +++ b/paddle/pir/core/op_operand_impl.h @@ -0,0 +1,60 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/pir/core/value.h" + +namespace pir { + +class Operation; + +namespace detail { +/// +/// \brief OpOperandImpl +/// +class OpOperandImpl { + public: + Operation *owner() const; + + OpOperandImpl *next_use(); + + Value source() const; + + void set_source(Value value); + + /// Remove this op_operand from the current use list. + void RemoveFromUdChain(); + + ~OpOperandImpl(); + + friend Operation; + + private: + OpOperandImpl(Value source, Operation *owner); + + // Insert self to the UD chain holded by source_; + // It is not safe. So set private. + void InsertToUdChain(); + + Value source_; + + OpOperandImpl *next_use_ = nullptr; + + OpOperandImpl **prev_use_addr_ = nullptr; + + Operation *const owner_ = nullptr; +}; + +} // namespace detail +} // namespace pir diff --git a/paddle/pir/core/op_result.cc b/paddle/pir/core/op_result.cc new file mode 100644 index 0000000000000..510f98d99b526 --- /dev/null +++ b/paddle/pir/core/op_result.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/pir/core/op_result.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/op_result_impl.h" + +#define CHECK_NULL_IMPL(class_name, func_name) \ + IR_ENFORCE(impl_, \ + "impl_ pointer is null when call func:" #func_name \ + " , in class: " #class_name ".") + +#define CHECK_OPRESULT_NULL_IMPL(func_name) CHECK_NULL_IMPL(OpResult, func_name) + +namespace pir { + +// OpResult +bool OpResult::classof(Value value) { + return value && pir::isa(value.impl()); +} + +Operation *OpResult::owner() const { + CHECK_OPRESULT_NULL_IMPL(owner); + return impl()->owner(); +} + +uint32_t OpResult::GetResultIndex() const { + CHECK_OPRESULT_NULL_IMPL(GetResultIndex); + return impl()->GetResultIndex(); +} + +detail::OpResultImpl *OpResult::impl() const { + return reinterpret_cast(impl_); +} + +bool OpResult::operator==(const OpResult &other) const { + return impl_ == other.impl_; +} + +uint32_t OpResult::GetValidInlineIndex(uint32_t index) { + uint32_t max_inline_index = + pir::detail::OpResultImpl::GetMaxInlineResultIndex(); + return index <= max_inline_index ? 
index : max_inline_index; +} + +} // namespace pir diff --git a/paddle/pir/core/op_result.h b/paddle/pir/core/op_result.h new file mode 100644 index 0000000000000..1a5f14a9a17fe --- /dev/null +++ b/paddle/pir/core/op_result.h @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/value.h" +namespace pir { + +namespace detail { +class OpResultImpl; +} // namespace detail + +/// +/// \brief OpResult class represents the value defined by a result of operation. +/// This class only provides interfaces, for specific implementation, see Impl +/// class. +/// +class IR_API OpResult : public Value { + public: + using Value::Value; + + static bool classof(Value value); + + Operation *owner() const; + + uint32_t GetResultIndex() const; + + bool operator==(const OpResult &other) const; + + friend Operation; + + detail::OpResultImpl *impl() const; + + private: + static uint32_t GetValidInlineIndex(uint32_t index); +}; + +} // namespace pir diff --git a/paddle/pir/core/op_result_impl.cc b/paddle/pir/core/op_result_impl.cc new file mode 100644 index 0000000000000..49b9f40259845 --- /dev/null +++ b/paddle/pir/core/op_result_impl.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/pir/core/op_result_impl.h" + +#include + +namespace pir { +namespace detail { + +uint32_t OpResultImpl::GetResultIndex() const { + if (const auto *outline_result = dyn_cast(this)) { + return outline_result->GetResultIndex(); + } + return dyn_cast(this)->GetResultIndex(); +} + +OpResultImpl::~OpResultImpl() { assert(use_empty()); } + +Operation *OpResultImpl::owner() const { + // For inline result, pointer offset index to obtain the address of op. + if (const auto *result = dyn_cast(this)) { + result += result->GetResultIndex() + 1; + return reinterpret_cast( + const_cast(result)); + } + // For outline result, pointer offset outline_index to obtain the address of + // maximum inline result. + const OpOutlineResultImpl *outline_result = + (const OpOutlineResultImpl *)(this); + outline_result += + (outline_result->outline_index_ - GetMaxInlineResultIndex()); + // The offset of the maximum inline result distance op is + // GetMaxInlineResultIndex. 
+ const auto *inline_result = + reinterpret_cast(outline_result); + inline_result += (GetMaxInlineResultIndex() + 1); + return reinterpret_cast( + const_cast(inline_result)); +} + +} // namespace detail +} // namespace pir diff --git a/paddle/pir/core/op_result_impl.h b/paddle/pir/core/op_result_impl.h new file mode 100644 index 0000000000000..99601a27911af --- /dev/null +++ b/paddle/pir/core/op_result_impl.h @@ -0,0 +1,93 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/value_impl.h" + +namespace pir { +namespace detail { +/// +/// \brief OpResultImpl is the implementation of an operation result. +/// +class OpResultImpl : public ValueImpl { + public: + using ValueImpl::ValueImpl; + + static bool classof(const ValueImpl &value) { + return value.kind() <= OUTLINE_OP_RESULT_INDEX; + } + + /// + /// \brief Get the parent operation of this result.(op_ptr = value_ptr + + /// index) + /// + Operation *owner() const; + + /// + /// \brief Get the result index of the operation result. + /// + uint32_t GetResultIndex() const; + + /// + /// \brief Get the maximum number of results that can be stored inline. + /// + static uint32_t GetMaxInlineResultIndex() { + return OUTLINE_OP_RESULT_INDEX - 1; + } + + ~OpResultImpl(); +}; + +/// +/// \brief OpInlineResultImpl is the implementation of an operation result whose +/// index <= 5. +/// +class OpInlineResultImpl : public OpResultImpl { + public: + OpInlineResultImpl(Type type, uint32_t result_index) + : OpResultImpl(type, result_index) { + if (result_index > GetMaxInlineResultIndex()) { + throw("Inline result index should not exceed MaxInlineResultIndex(5)"); + } + } + + static bool classof(const OpResultImpl &value) { + return value.kind() < OUTLINE_OP_RESULT_INDEX; + } + + uint32_t GetResultIndex() const { return kind(); } +}; + +/// +/// \brief OpOutlineResultImpl is the implementation of an operation result +/// whose index > 5. 
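For orientation, the result classes in this file combine with the `Operation::Create` ordering described below into the following layout (a sketch; `OUTLINE_OP_RESULT_INDEX == 6` is inferred from `GetMaxInlineResultIndex()` and the "index <= 5" comments above):

```cpp
// Memory picture for an operation with 8 results, as allocated by
// Operation::Create (results sit before the Operation, outline first):
//
//   [OpOutlineResultImpl idx=7]
//   [OpOutlineResultImpl idx=6]  kind() == OUTLINE_OP_RESULT_INDEX
//   [OpInlineResultImpl  idx=5]
//   ...
//   [OpInlineResultImpl  idx=0]  adjacent to the Operation object
//   [Operation]
//   [OpOperandImpl 0 .. n-1]     operands follow the Operation
//
// This is why OpResultImpl::owner() can recover the Operation* with
// pointer arithmetic alone: an inline result at index k lies exactly
// (k + 1) inline-result slots before the Operation.
```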
+/// +class OpOutlineResultImpl : public OpResultImpl { + public: + OpOutlineResultImpl(Type type, uint32_t outline_index) + : OpResultImpl(type, OUTLINE_OP_RESULT_INDEX), + outline_index_(outline_index) {} + + static bool classof(const OpResultImpl &value) { + return value.kind() == OUTLINE_OP_RESULT_INDEX; + } + + uint32_t GetResultIndex() const { return outline_index_; } + + uint32_t outline_index_; +}; + +} // namespace detail +} // namespace pir diff --git a/paddle/ir/core/operation.cc b/paddle/pir/core/operation.cc similarity index 93% rename from paddle/ir/core/operation.cc rename to paddle/pir/core/operation.cc index 3d316847d9fc1..fdb850bc1f415 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -14,18 +14,18 @@ #include -#include "paddle/ir/core/block.h" -#include "paddle/ir/core/block_operand_impl.h" -#include "paddle/ir/core/dialect.h" -#include "paddle/ir/core/enforce.h" -#include "paddle/ir/core/op_info.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/region.h" -#include "paddle/ir/core/utils.h" -#include "paddle/ir/core/value_impl.h" - -namespace ir { +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/block_operand_impl.h" +#include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/op_info.h" +#include "paddle/pir/core/op_result_impl.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/region.h" +#include "paddle/pir/core/utils.h" + +namespace pir { Operation *Operation::Create(OperationArgument &&argument) { return Create(argument.inputs, argument.attributes, @@ -38,10 +38,10 @@ Operation *Operation::Create(OperationArgument &&argument) { // Allocate the required memory based on the size and number of inputs, outputs, // and operators, and construct it in the order of: OpOutlineResult, // OpInlineResult, Operation, operand. -Operation *Operation::Create(const std::vector &inputs, +Operation *Operation::Create(const std::vector &inputs, const AttributeMap &attributes, const std::vector &output_types, - ir::OpInfo op_info, + pir::OpInfo op_info, size_t num_regions, const std::vector &successors) { // 1. 
Calculate the required memory size for OpResults + Operation + @@ -179,7 +179,7 @@ IrContext *Operation::ir_context() const { return info_.ir_context(); } Dialect *Operation::dialect() const { return info_.dialect(); } Operation::Operation(const AttributeMap &attributes, - ir::OpInfo op_info, + pir::OpInfo op_info, uint32_t num_results, uint32_t num_operands, uint32_t num_regions, @@ -191,7 +191,7 @@ Operation::Operation(const AttributeMap &attributes, num_regions_(num_regions), num_successors_(num_successors) {} -ir::OpResult Operation::result(uint32_t index) const { +pir::OpResult Operation::result(uint32_t index) const { if (index >= num_results_) { IR_THROW("index exceeds OP output range."); } @@ -204,10 +204,10 @@ ir::OpResult Operation::result(uint32_t index) const { : reinterpret_cast(this) - (index + 1) * sizeof(detail::OpInlineResultImpl); if (index > max_inline_idx) { - return ir::OpResult( + return pir::OpResult( reinterpret_cast(ptr)); } else { - return ir::OpResult( + return pir::OpResult( reinterpret_cast(ptr)); } } @@ -318,4 +318,4 @@ std::vector Operation::results() const { return res; } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/operation.h b/paddle/pir/core/operation.h similarity index 90% rename from paddle/ir/core/operation.h rename to paddle/pir/core/operation.h index 961e4a5fccc50..28c0b42671c96 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -16,14 +16,14 @@ #include #include -#include "paddle/ir/core/block.h" -#include "paddle/ir/core/enforce.h" -#include "paddle/ir/core/macros.h" -#include "paddle/ir/core/op_info.h" -#include "paddle/ir/core/operation_utils.h" -#include "paddle/ir/core/type.h" - -namespace ir { +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/macros.h" +#include "paddle/pir/core/op_info.h" +#include "paddle/pir/core/operation_utils.h" +#include "paddle/pir/core/type.h" + +namespace pir { class OpBase; class Program; class OpOperand; @@ -41,10 +41,10 @@ class IR_API alignas(8) Operation final { /// NOTE: Similar to new and delete, the destroy() and the create() need to be /// used in conjunction. 
/// - static Operation *Create(const std::vector &inputs, + static Operation *Create(const std::vector &inputs, const AttributeMap &attributes, - const std::vector &output_types, - ir::OpInfo op_info, + const std::vector &output_types, + pir::OpInfo op_info, size_t num_regions = 0, const std::vector &successors = {}); static Operation *Create(OperationArgument &&op_argument); @@ -96,7 +96,7 @@ class IR_API alignas(8) Operation final { return attributes_.find(key) != attributes_.end(); } - ir::OpInfo info() const { return info_; } + pir::OpInfo info() const { return info_; } uint32_t num_results() const { return num_results_; } @@ -164,7 +164,7 @@ class IR_API alignas(8) Operation final { private: DISABLE_COPY_AND_ASSIGN(Operation); Operation(const AttributeMap &attribute, - ir::OpInfo op_info, + pir::OpInfo op_info, uint32_t num_results, uint32_t num_operands, uint32_t num_regions, @@ -203,4 +203,4 @@ class IR_API alignas(8) Operation final { Block::iterator position_; }; -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/operation_utils.cc b/paddle/pir/core/operation_utils.cc similarity index 83% rename from paddle/ir/core/operation_utils.cc rename to paddle/pir/core/operation_utils.cc index f975de0c82807..a8eedcfcb8c48 100644 --- a/paddle/ir/core/operation_utils.cc +++ b/paddle/pir/core/operation_utils.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ir/core/operation_utils.h" -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/region.h" +#include "paddle/pir/core/operation_utils.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/region.h" -namespace ir { +namespace pir { OperationArgument::OperationArgument(IrContext* ir_context, const std::string& name) { info = ir_context->GetRegisteredOpInfo(name); } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/operation_utils.h b/paddle/pir/core/operation_utils.h similarity index 92% rename from paddle/ir/core/operation_utils.h rename to paddle/pir/core/operation_utils.h index 9e317a6510f59..39c41c6eae2c3 100644 --- a/paddle/ir/core/operation_utils.h +++ b/paddle/pir/core/operation_utils.h @@ -15,13 +15,14 @@ #pragma once #include -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/op_info.h" -#include "paddle/ir/core/region.h" -#include "paddle/ir/core/type.h" -#include "paddle/ir/core/value.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/op_info.h" +#include "paddle/pir/core/op_result.h" +#include "paddle/pir/core/region.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/value.h" -namespace ir { +namespace pir { class Block; using AttributeMap = std::unordered_map; @@ -100,4 +101,4 @@ void OperationArgument::AddAttributes(InputIt first, InputIt last) { } } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/parameter.h b/paddle/pir/core/parameter.h similarity index 92% rename from paddle/ir/core/parameter.h rename to paddle/pir/core/parameter.h index 3dbe48935b09a..332ef23322e01 100644 --- a/paddle/ir/core/parameter.h +++ b/paddle/pir/core/parameter.h @@ -14,15 +14,15 @@ #pragma once -#include "paddle/ir/core/type.h" +#include "paddle/pir/core/type.h" -namespace ir { +namespace pir { /// /// \brief Parameter represents the weight in the calculation graph. 
/// class IR_API Parameter { public: - Parameter(void* data, size_t size, ir::Type type) { + Parameter(void* data, size_t size, pir::Type type) { data_ = malloc(size); memcpy(data_, data, size); size_ = size; @@ -67,4 +67,4 @@ class IR_API Parameter { bool is_mutable_ = false; }; -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/parser/ir_parser.cc b/paddle/pir/core/parser/ir_parser.cc similarity index 96% rename from paddle/ir/core/parser/ir_parser.cc rename to paddle/pir/core/parser/ir_parser.cc index 8d7e437635165..3fe336fc63289 100644 --- a/paddle/ir/core/parser/ir_parser.cc +++ b/paddle/pir/core/parser/ir_parser.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ir/core/ir_parser.h" +#include "paddle/pir/core/parser/ir_parser.h" -#include "paddle/ir/core/builtin_dialect.h" -#include "paddle/ir/core/builtin_type.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/core/builtin_type.h" -namespace ir { +namespace pir { IrParser::IrParser(IrContext* ctx, std::istream& is) { lexer.reset(new Lexer{is}); this->ctx = ctx; @@ -216,9 +216,9 @@ Operation* IrParser::ParseOperation() { OpInfo opinfo = ParseOpInfo(); - std::vector inputs = ParseOpRandList(); + std::vector inputs = ParseOprandList(); - ir::AttributeMap attributeMap = ParseAttributeMap(); + pir::AttributeMap attributeMap = ParseAttributeMap(); ConsumeAToken(":"); ConsumeAToken("("); @@ -269,7 +269,7 @@ OpInfo IrParser::ParseOpInfo() { // OprandList := ValueList // ValueList := ValueId(,ValueId)* -std::vector IrParser::ParseOpRandList() { +std::vector IrParser::ParseOprandList() { ConsumeAToken("("); std::vector inputs{}; Token ind_token = ConsumeToken(); @@ -348,4 +348,4 @@ std::unique_ptr Program::Parse(std::istream& is, IrContext* ctx) { return parser.ParseProgram(); } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/ir_parser.h b/paddle/pir/core/parser/ir_parser.h similarity index 80% rename from paddle/ir/core/ir_parser.h rename to paddle/pir/core/parser/ir_parser.h index dbba3e2aaba80..c10e88225984b 100644 --- a/paddle/ir/core/ir_parser.h +++ b/paddle/pir/core/parser/ir_parser.h @@ -13,16 +13,16 @@ // limitations under the License. 
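`operator<<(std::ostream&, const Program&)` from ir_printer.cc and `Program::Parse` above together give the textual IR a round-trip; a sketch, assuming the relevant dialects are already registered in the context:

```cpp
#include <memory>
#include <sstream>

// Serialize a Program through IrPrinter and re-parse it with IrParser.
// Program::Parse is the entry point defined at the end of ir_parser.cc.
std::unique_ptr<pir::Program> RoundTrip(const pir::Program &program,
                                        pir::IrContext *ctx) {
  std::stringstream ss;
  ss << program;
  return pir::Program::Parse(ss, ctx);
}
```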
#pragma once -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/parser/lexer.h" -#include "paddle/ir/core/program.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/parser/lexer.h" +#include "paddle/pir/core/program.h" -using OpResultMap = std::map; -using AttributeMap = std::unordered_map; +using OpResultMap = std::map; +using AttributeMap = std::unordered_map; using OpAttributeInfoMap = std::map; -namespace ir { +namespace pir { class IrParser { public: std::unique_ptr lexer; @@ -51,7 +51,7 @@ class IrParser { std::vector ParseOpResultList(); - std::vector ParseOpRandList(); + std::vector ParseOprandList(); AttributeMap ParseAttributeMap(); @@ -68,4 +68,4 @@ class IrParser { void ConsumeAToken(std::string expect_token_val); }; -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/parser/lexer.cc b/paddle/pir/core/parser/lexer.cc similarity index 99% rename from paddle/ir/core/parser/lexer.cc rename to paddle/pir/core/parser/lexer.cc index af1530a5b2961..c7f037de9927d 100644 --- a/paddle/ir/core/parser/lexer.cc +++ b/paddle/pir/core/parser/lexer.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ir/core/parser/lexer.h" +#include "paddle/pir/core/parser/lexer.h" Token Lexer::ConsumeToken() { SkipWhitespace(); diff --git a/paddle/ir/core/parser/lexer.h b/paddle/pir/core/parser/lexer.h similarity index 96% rename from paddle/ir/core/parser/lexer.h rename to paddle/pir/core/parser/lexer.h index 0561e1f60caa8..24694eb761317 100644 --- a/paddle/ir/core/parser/lexer.h +++ b/paddle/pir/core/parser/lexer.h @@ -16,7 +16,7 @@ #include #include -#include "paddle/ir/core/parser/token.h" +#include "paddle/pir/core/parser/token.h" class Lexer { private: diff --git a/paddle/ir/core/parser/token.h b/paddle/pir/core/parser/token.h similarity index 100% rename from paddle/ir/core/parser/token.h rename to paddle/pir/core/parser/token.h diff --git a/paddle/ir/core/program.cc b/paddle/pir/core/program.cc similarity index 90% rename from paddle/ir/core/program.cc rename to paddle/pir/core/program.cc index baf6a3cbdd57c..d4197a4a9bc4b 100644 --- a/paddle/ir/core/program.cc +++ b/paddle/pir/core/program.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/ir/core/program.h" -#include "paddle/ir/core/ir_context.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/ir_context.h" -namespace ir { +namespace pir { Program::Program(IrContext* context) { module_ = ModuleOp::Create(context, this); @@ -39,4 +39,4 @@ void Program::SetParameter(const std::string& name, parameters_[name].reset(parameter.release()); } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/program.h b/paddle/pir/core/program.h similarity index 89% rename from paddle/ir/core/program.h rename to paddle/pir/core/program.h index bf9c37210967e..8756b3aa70e1c 100644 --- a/paddle/ir/core/program.h +++ b/paddle/pir/core/program.h @@ -18,14 +18,14 @@ #include #include -#include "paddle/ir/core/attribute.h" -#include "paddle/ir/core/block.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/operation.h" -#include "paddle/ir/core/parameter.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/parameter.h" -namespace ir { +namespace pir { class IrContext; /// @@ -75,4 +75,4 @@ class IR_API Program { std::ostream& operator<<(std::ostream& os, const Program& prog); -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/region.cc b/paddle/pir/core/region.cc similarity index 91% rename from paddle/ir/core/region.cc rename to paddle/pir/core/region.cc index e9fdb91758219..0f02e3d19e04e 100644 --- a/paddle/ir/core/region.cc +++ b/paddle/pir/core/region.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
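`Parameter` (parameter.h above) deep-copies the caller's buffer in its constructor, and `Program::SetParameter` takes ownership of the `unique_ptr`, so handing over short-lived stack data is safe. A sketch, where `Float32Type` is assumed from builtin_type.h:

```cpp
#include <memory>
#include <vector>

// Register a weight with a Program; Parameter memcpy's the buffer in
// its constructor, so `data` does not need to outlive this call.
void AddWeight(pir::Program *program, pir::IrContext *ctx) {
  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
  auto param = std::make_unique<pir::Parameter>(
      data.data(), data.size() * sizeof(float), pir::Float32Type::get(ctx));
  program->SetParameter("w0", std::move(param));  // program owns it now
}
```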
-#include "paddle/ir/core/region.h" -#include "paddle/ir/core/block.h" -#include "paddle/ir/core/enforce.h" -#include "paddle/ir/core/operation.h" +#include "paddle/pir/core/region.h" +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/operation.h" -namespace ir { +namespace pir { Region::~Region() { clear(); } void Region::push_back(Block *block) { insert(blocks_.end(), block); } @@ -61,4 +61,4 @@ IrContext *Region::ir_context() const { IR_ENFORCE(parent_, "Region is not attached to a container."); return parent_->ir_context(); } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/region.h b/paddle/pir/core/region.h similarity index 96% rename from paddle/ir/core/region.h rename to paddle/pir/core/region.h index cc1c1ab791df5..06272f82a4378 100644 --- a/paddle/ir/core/region.h +++ b/paddle/pir/core/region.h @@ -17,9 +17,9 @@ #include #include -#include "paddle/ir/core/dll_decl.h" +#include "paddle/pir/core/dll_decl.h" -namespace ir { +namespace pir { class Block; class Operation; @@ -68,4 +68,4 @@ class IR_API Region { Operation *parent_{nullptr}; // not owned std::list blocks_; // owned }; -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/spin_lock.h b/paddle/pir/core/spin_lock.h similarity index 97% rename from paddle/ir/core/spin_lock.h rename to paddle/pir/core/spin_lock.h index 4150f419c3159..5cba96823a817 100644 --- a/paddle/ir/core/spin_lock.h +++ b/paddle/pir/core/spin_lock.h @@ -23,7 +23,7 @@ #include #include -namespace ir { +namespace pir { static inline void CpuRelax() { #if defined(__PADDLE_x86__) _mm_pause(); @@ -63,4 +63,4 @@ class SpinLock { std::atomic mlock_; }; -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/storage_manager.cc b/paddle/pir/core/storage_manager.cc similarity index 86% rename from paddle/ir/core/storage_manager.cc rename to paddle/pir/core/storage_manager.cc index 0dcc7ca0ad855..07cc4e07cce2c 100644 --- a/paddle/ir/core/storage_manager.cc +++ b/paddle/pir/core/storage_manager.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ir/core/storage_manager.h" +#include "paddle/pir/core/storage_manager.h" #include #include -#include "paddle/ir/core/enforce.h" +#include "paddle/pir/core/enforce.h" -namespace ir { +namespace pir { // This is a structure for creating, caching, and looking up Storage of // parametric types. 
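`SpinLock` (spin_lock.h above) is what the registries in `IrContextImpl` and `StorageManager` guard themselves with throughout this patch, always through `std::lock_guard<pir::SpinLock>`. A minimal sketch, assuming `SpinLock` exposes the usual `lock()`/`unlock()` pair:

```cpp
#include <mutex>

// pir::SpinLock satisfies BasicLockable, so std::lock_guard composes
// with it exactly as it does with std::mutex; this is the guard pattern
// used by the registration and lookup paths above.
pir::SpinLock registry_lock;

void WithRegistryLocked() {
  std::lock_guard<pir::SpinLock> guard(registry_lock);
  // ... read or mutate the shared registry while the lock is held ...
}
```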
struct ParametricStorageManager { @@ -75,9 +75,9 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageImpl( std::size_t hash_value, std::function equal_func, std::function constructor) { - std::lock_guard guard(parametric_instance_lock_); + std::lock_guard guard(parametric_instance_lock_); VLOG(6) << "Try to get a parametric storage of: [TypeId_hash=" - << std::hash()(type_id) << ", param_hash=" << hash_value + << std::hash()(type_id) << ", param_hash=" << hash_value << "]."; if (parametric_instance_.find(type_id) == parametric_instance_.end()) { IR_THROW("The input data pointer is null."); @@ -88,9 +88,9 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageImpl( StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl( TypeId type_id) { - std::lock_guard guard(parameterless_instance_lock_); + std::lock_guard guard(parameterless_instance_lock_); VLOG(6) << "Try to get a parameterless storage of: [TypeId_hash=" - << std::hash()(type_id) << "]."; + << std::hash()(type_id) << "]."; if (parameterless_instance_.find(type_id) == parameterless_instance_.end()) IR_THROW("TypeId not found in IrContext."); StorageBase *parameterless_instance = parameterless_instance_[type_id]; @@ -99,21 +99,21 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl( void StorageManager::RegisterParametricStorageImpl( TypeId type_id, std::function destroy) { - std::lock_guard guard(parametric_instance_lock_); + std::lock_guard guard(parametric_instance_lock_); VLOG(6) << "Register a parametric storage of: [TypeId_hash=" - << std::hash()(type_id) << "]."; + << std::hash()(type_id) << "]."; parametric_instance_.emplace( type_id, std::make_unique(destroy)); } void StorageManager::RegisterParameterlessStorageImpl( TypeId type_id, std::function constructor) { - std::lock_guard guard(parameterless_instance_lock_); + std::lock_guard guard(parameterless_instance_lock_); VLOG(6) << "Register a parameterless storage of: [TypeId_hash=" - << std::hash()(type_id) << "]."; + << std::hash()(type_id) << "]."; if (parameterless_instance_.find(type_id) != parameterless_instance_.end()) IR_THROW("storage class already registered"); parameterless_instance_.emplace(type_id, constructor()); } -} // namespace ir +} // namespace pir diff --git a/paddle/ir/core/storage_manager.h b/paddle/pir/core/storage_manager.h similarity index 96% rename from paddle/ir/core/storage_manager.h rename to paddle/pir/core/storage_manager.h index f2cda194ce215..1853207f5953f 100644 --- a/paddle/ir/core/storage_manager.h +++ b/paddle/pir/core/storage_manager.h @@ -18,10 +18,10 @@ #include #include -#include "paddle/ir/core/spin_lock.h" -#include "paddle/ir/core/type_id.h" +#include "paddle/pir/core/spin_lock.h" +#include "paddle/pir/core/type_id.h" -namespace ir { +namespace pir { /// /// \brief The implementation of the class StorageManager. /// @@ -141,12 +141,12 @@ class IR_API StorageManager { std::unordered_map> parametric_instance_; - ir::SpinLock parametric_instance_lock_; + pir::SpinLock parametric_instance_lock_; // This map is a mapping between type id and parameterless type storage. 
std::unordered_map parameterless_instance_; - ir::SpinLock parameterless_instance_lock_; + pir::SpinLock parameterless_instance_lock_; }; -} // namespace ir +} // namespace pir diff --git a/paddle/pir/core/storage_manager_support.h b/paddle/pir/core/storage_manager_support.h new file mode 100644 index 0000000000000..a54e066a0e2a6 --- /dev/null +++ b/paddle/pir/core/storage_manager_support.h @@ -0,0 +1,106 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/interface_support.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/type_base.h" +#include "paddle/pir/core/type_id.h" + +namespace pir { +template +class TypeInterfaceBase; + +namespace detail { + +namespace storage_helper_base_impl { +/// Returns true if this given Trait ID matches the IDs of any of the provided +/// trait types `Traits`. +template