From 17554507b08dded7c93afcf79ca0bda0e39bd4ff Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Thu, 19 Aug 2021 19:03:47 +0800 Subject: [PATCH] Add piano op registry (#19) * draft of piano op register, version 0 * fix compile problem, now single test passed * full single test script * delete useless __COUNTER__ macro and move file into paddle2piano directory * change OpDesc to Operand * realize piano execution context getattr, and add single test script * optimize make order and optimize single test script * optimize Instance from pointer to reference * decoupling op-registry and execution context, move execution context to other PR * remove useless check macro * add AddAllowBackendList in PianoOpMaker * decoupling ExecutionContext, using independent PianoContext instead * PianoOp just need one kernel, move datatype and layout information into OpRegistration * merge main branch code and update ElementType to ElementTypeProto * remove Makefile useless code and add final keyword for BindOp * remove IsDerived in PianoOpMaker * optimize op-registry class structure according to CtfGo's advice --- .../compiler/paddle2piano/CMakeLists.txt | 4 + .../compiler/paddle2piano/piano_op_kernel.h | 30 ++ .../paddle2piano/piano_op_registry.cc | 103 +++++++ .../compiler/paddle2piano/piano_op_registry.h | 265 ++++++++++++++++++ .../paddle2piano/piano_op_registry_test.cc | 177 ++++++++++++ 5 files changed, 579 insertions(+) create mode 100644 paddle/fluid/compiler/paddle2piano/piano_op_kernel.h create mode 100644 paddle/fluid/compiler/paddle2piano/piano_op_registry.cc create mode 100644 paddle/fluid/compiler/paddle2piano/piano_op_registry.h create mode 100644 paddle/fluid/compiler/paddle2piano/piano_op_registry_test.cc diff --git a/paddle/fluid/compiler/paddle2piano/CMakeLists.txt b/paddle/fluid/compiler/paddle2piano/CMakeLists.txt index e33ea63a55ed8c..822482313c75f1 100644 --- a/paddle/fluid/compiler/paddle2piano/CMakeLists.txt +++ b/paddle/fluid/compiler/paddle2piano/CMakeLists.txt @@ -1,2 +1,6 @@ + cc_library(piano_compile_pass SRCS piano_compile_pass.cc DEPS pass subgraph_detector) cc_test(piano_compile_pass_test SRCS piano_compile_pass_tester.cc DEPS piano_compile_pass) + +cc_library(piano_op_registry SRCS piano_op_registry.cc DEPS framework_proto note_proto piano_data_description) +cc_test(piano_op_registry_test SRCS piano_op_registry_test.cc DEPS piano_op_registry operator op_registry) diff --git a/paddle/fluid/compiler/paddle2piano/piano_op_kernel.h b/paddle/fluid/compiler/paddle2piano/piano_op_kernel.h new file mode 100644 index 00000000000000..a9ac14c62617be --- /dev/null +++ b/paddle/fluid/compiler/paddle2piano/piano_op_kernel.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace piano { + +class PianoOpKernelContext; + +class PianoOpKernel { + public: + virtual void Compile(const PianoOpKernelContext& context) const = 0; + + virtual ~PianoOpKernel() = default; +}; + +} // namespace piano +} // namespace paddle diff --git a/paddle/fluid/compiler/paddle2piano/piano_op_registry.cc b/paddle/fluid/compiler/paddle2piano/piano_op_registry.cc new file mode 100644 index 00000000000000..131001736b405e --- /dev/null +++ b/paddle/fluid/compiler/paddle2piano/piano_op_registry.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/compiler/paddle2piano/piano_op_registry.h" + +#include + +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace piano { + +void PianoOpRegistry::RegisterBackend( + const std::string& backend_name, + const std::unordered_set& supported_types, + BackendFilterFunc filter_func) { + PADDLE_ENFORCE_EQ( + PianoOpRegistry::IsBackend(backend_name), false, + platform::errors::AlreadyExists("Backend %s has been registered.", + backend_name.c_str())); + auto& registry = Instance(); + registry.backend_.emplace(backend_name, new Backend); + + auto& backend = registry.backend_.at(backend_name); + backend->name = backend_name; + backend->supported_types = supported_types; + backend->filter_func = filter_func; +} + +const std::unordered_set& +PianoOpRegistry::BackendDataTypes(const std::string& backend_name) { + PADDLE_ENFORCE_EQ(IsBackend(backend_name), true, + platform::errors::NotFound("Name %s not founded Backend.", + backend_name.c_str())); + return Instance().backend_.at(backend_name)->supported_types; +} + +std::vector PianoOpRegistry::AllBackendNames() { + auto& registry = Instance(); + std::vector ret; + for (const auto& backend_pair : registry.backend_) { + ret.emplace_back(backend_pair.first); + } + return ret; +} + +bool PianoOpRegistry::HasAllowBackendList(const std::string& op_type) { + PADDLE_ENFORCE_EQ( + IsPianoOp(op_type), true, + platform::errors::NotFound("OP %s is not Piano Op.", op_type.c_str())); + return Instance().ops_.at(op_type)->has_allow_backend_list; +} + +std::vector PianoOpRegistry::AllPianoOps() { + auto& registry = Instance(); + std::vector ret; + for (const auto& op_pair : registry.ops_) { + ret.emplace_back(op_pair.first); + } + return ret; +} + +const PianoOpRegistry::OpKernelMap& PianoOpRegistry::AllPianoOpKernels( + const std::string& op_type) { + PADDLE_ENFORCE_EQ( + IsPianoOp(op_type), true, + platform::errors::NotFound("OP %s is not Piano Op.", op_type.c_str())); + + return Instance().ops_.at(op_type)->kernel_; +} + +const framework::AttributeMap& PianoOpRegistry::Attrs( + const std::string& op_type) { + PADDLE_ENFORCE_EQ( + PianoOpRegistry::IsPianoOp(op_type), true, + platform::errors::NotFound("OP %s is not Piano Op.", op_type.c_str())); + + return Instance().ops_.at(op_type)->attrs; +} + +const std::unordered_set& +PianoOpRegistry::PianoOpDataTypes(const std::string& op_type) { + PADDLE_ENFORCE_EQ( + PianoOpRegistry::IsPianoOp(op_type), true, + platform::errors::NotFound("OP %s is not Piano Op.", op_type.c_str())); + + return Instance().ops_.at(op_type)->supported_types; +} + +} // namespace piano +} // namespace paddle diff --git a/paddle/fluid/compiler/paddle2piano/piano_op_registry.h b/paddle/fluid/compiler/paddle2piano/piano_op_registry.h new file mode 100644 index 00000000000000..bef43862e7d221 --- /dev/null +++ b/paddle/fluid/compiler/paddle2piano/piano_op_registry.h @@ -0,0 +1,265 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/compiler/piano/note/note.pb.h" +#include "paddle/fluid/compiler/piano/note_builder.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace piano { + +class PianoOpKernelContext; +class PianoOpMaker; + +class PianoOpRegistry final { + public: + using OpKernelFunc = std::function; + using OpKernelMap = std::unordered_map; + + // Register a Piano backend. + // `name` is the backend name. `supported_types` is data type list, + // this backend can only accept the data type in list. `filter_func` is + // a function, return false if the backend refuse this op. + using BackendFilterFunc = bool (*)(Operand*); + static void RegisterBackend( + const std::string& backend_name, + const std::unordered_set& supported_types, + BackendFilterFunc filter_func); + + static inline bool IsBackend(const std::string& backend_name) { + return Instance().backend_.count(backend_name) > 0; + } + + static std::vector AllBackendNames(); + + static const std::unordered_set& BackendDataTypes( + const std::string& backend_name); + + // Piano Op interface + static inline bool IsPianoOp(const std::string& op_type) { + return Instance().ops_.count(op_type) > 0; + } + + static std::vector AllPianoOps(); + + static bool HasAllowBackendList(const std::string& op_type); + + static std::vector AllowBackendList(const std::string& op_type) { + return HasAllowBackendList(op_type) + ? Instance().ops_.at(op_type)->allow_backend_list + : AllBackendNames(); + } + + static const std::unordered_set& PianoOpDataTypes( + const std::string& op_type); + + static const framework::AttributeMap& Attrs(const std::string& op_type); + + static void RegisterKernel(const std::string& op_type, + const std::string& library_type, + OpKernelFunc func) { + // save kernel information into kernel_ map + Instance().ops_.at(op_type)->kernel_.emplace(library_type, func); + } + + static const OpKernelMap& AllPianoOpKernels(const std::string& op_type); + + private: + // Declare PianoOpMaker friend class so that AddAttr can add attribute into + // ops_'s attrs value. + // Why not define an AddAttr function in PianoOpRegistry? Only PianoOpMaker + // can access attribute. + friend class PianoOpMaker; + + // register class + template + friend class PianoOpRegistrar; + + static PianoOpRegistry& Instance() { + static PianoOpRegistry r; + return r; + } + + PianoOpRegistry() = default; + ~PianoOpRegistry() = default; + + DISABLE_COPY_AND_ASSIGN(PianoOpRegistry); + + // Describes a Piano backend + struct Backend { + std::string name; + std::unordered_set supported_types; + + // A filter function used to exclude or modify operator + // registrations on the device. If nullptr, the backend + // accept all op, else it should return false if the op + // cannot register at this backend. + // The function may modify operator to adapt the backend. + BackendFilterFunc filter_func = nullptr; + }; + + // Map from backend name to its descriptor + std::unordered_map> backend_; + + // Describes a Paddle operator that can be compiled to Piano + struct OpRegistration { + std::string op_type; + std::unordered_set supported_types; + + bool has_allow_backend_list = false; + std::vector allow_backend_list; + + // Different to OpProto::attrs, these attribute are only used for + // Piano, which can be obtained at Piano compile time. + framework::AttributeMap attrs; + + std::unique_ptr maker; + + // Piano Op kernel map, the key is library name and its value is a + // kernel function, the kernel function override the "Compile" + // interface of "PianoOpKernel" class. + OpKernelMap kernel_; + }; + + // Map from operator name to its descriptor + std::unordered_map> ops_; +}; + +// just used for mark final keyword +class PianoOpMakerBase { + public: + virtual void BindOp(const std::string& op_type) = 0; + virtual ~PianoOpMakerBase() = default; +}; + +class PianoOpMaker : public PianoOpMakerBase { + public: + virtual void Make() = 0; + + virtual ~PianoOpMaker() = default; + + // Do not rewrite this API in derived class! + void BindOp(const std::string& op_type) final { + this->op_ = PianoOpRegistry::Instance().ops_.at(op_type).get(); + } + + protected: + // cover the old one if existed a same name attribute + // Do not rewrite this API in derived class! + template + void AddAttr(const std::string& name, const T& val) { + op_->attrs.emplace(name, val); + } + + void SetAllowBackendList(const std::vector& backends) { + op_->has_allow_backend_list = true; + op_->allow_backend_list = backends; + } + + void SetDataTypes( + const std::unordered_set& data_types) { + op_->supported_types.insert(data_types.cbegin(), data_types.cend()); + } + + private: + PianoOpRegistry::OpRegistration* op_; +}; + +template +class PianoOpRegistrar final : public framework::Registrar { + public: + PianoOpRegistrar(const char* op_type, const char* library_type) { + using paddle::framework::OpInfoMap; + PADDLE_ENFORCE_EQ(OpInfoMap::Instance().Has(op_type), true, + platform::errors::NotFound( + "Piano OP should registered in Paddle before, " + "but %s not. Please use \"REGISTER_OPERATOR\" " + "before register Piano OP.", + op_type)); + + PADDLE_ENFORCE_EQ(PianoOpRegistry::IsPianoOp(op_type), false, + platform::errors::AlreadyExists( + "Piano OP %s has been registered.", op_type)); + + // bind PianoOpMaker class for add Piano attribute later + // Do need check whether OpMakeType derive from PianoOpMaker? + static_assert(std::is_base_of::value, + "The OpMaker class is not derived from PianoOpMaker."); + + // create and bind OpRegistration + auto& registry = PianoOpRegistry::Instance(); + registry.ops_.emplace(op_type, new PianoOpRegistry::OpRegistration); + auto& op_reg = registry.ops_.at(op_type); + op_reg->op_type = op_type; + + // bind PianoOpMaker class for add Piano attribute later + op_reg->maker.reset(new OpMakeType); + op_reg->maker->BindOp(op_type); + // TODO(jiangcheng05): here invoke Make() is not a good idea + op_reg->maker->Make(); + + PianoOpRegistry::RegisterKernel( + op_type, library_type, + [](const PianoOpKernelContext& ctx) { KernelType().Compile(ctx); }); + } +}; + +#define REGISTER_PIANO_OP(op_type, op_maker, op_kernel) \ + REGISTER_PIANO_OP_EX(op_type, PLAIN, op_maker, op_kernel) + +#define REGISTER_PIANO_OP_EX(TYPE, LIB, MAKER, KERNEL) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_piano_op__##TYPE##_##LIB, \ + "REGISTER_PIANO_OP must be called in global namespace"); \ + static ::paddle::piano::PianoOpRegistrar \ + __piano_op_registrar__##TYPE##_##LIB##__(#TYPE, #LIB); \ + int TouchPianoOpRegistrar_##TYPE##_##LIB() { \ + __piano_op_registrar__##TYPE##_##LIB##__.Touch(); \ + return 0; \ + } + +class BackendRegistrar final : public framework::Registrar { + public: + BackendRegistrar( + const char* backend_name, + const std::unordered_set& supported_types, + PianoOpRegistry::BackendFilterFunc filter_func = nullptr) { + PianoOpRegistry::RegisterBackend(backend_name, supported_types, + filter_func); + } +}; + +#define REGISTER_PIANO_BACKEND(NAME, ...) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_piano_backend__##NAME, \ + "REGISTER_PIANO_BACKEND must be called in global namespace"); \ + static ::paddle::piano::BackendRegistrar \ + __piano_backend_registrar__##NAME##__(#NAME, __VA_ARGS__); \ + int TouchBackendRegistrar_##NAME() { \ + __piano_backend_registrar__##NAME##__.Touch(); \ + return 0; \ + } + +} // namespace piano +} // namespace paddle diff --git a/paddle/fluid/compiler/paddle2piano/piano_op_registry_test.cc b/paddle/fluid/compiler/paddle2piano/piano_op_registry_test.cc new file mode 100644 index 00000000000000..f1bd0e3919022d --- /dev/null +++ b/paddle/fluid/compiler/paddle2piano/piano_op_registry_test.cc @@ -0,0 +1,177 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/compiler/paddle2piano/piano_op_registry.h" + +#include +#include + +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "paddle/fluid/compiler/paddle2piano/piano_op_kernel.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace piano { + +using paddle::framework::InferShapeContext; +using paddle::framework::OpProtoAndCheckerMaker; +using paddle::framework::OperatorWithKernel; + +class TestOp : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + void InferShape(InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "test"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "test"); + + auto in_dims = ctx->GetInputDim("X"); + + ctx->SetOutputDim("Out", in_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class TestOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of test op."); + AddOutput("Out", "(Tensor), The output tensor of test op."); + AddComment(R"DOC( + Test Operator. + + This operator is used to test piano test op registry OK. + + )DOC"); + } +}; + +std::unordered_set TestDatatypes() { + static std::unordered_set supported_types = { + note::F16, note::F32, note::F64}; + return supported_types; +} + +bool TestFilterFunc(Operand* op) { + // TODO(jiangcheng05) : fill some change of Operand + return true; +} +} // namespace piano +} // namespace paddle + +// register paddle op +REGISTER_OP_WITHOUT_GRADIENT(test, paddle::piano::TestOp, + paddle::piano::TestOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(op_not_piano, paddle::piano::TestOp, + paddle::piano::TestOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(test_limit_backend, paddle::piano::TestOp, + paddle::piano::TestOpMaker); + +// register backend which support some datetype but all op +REGISTER_PIANO_BACKEND(PIANO_JIT_TEST_TYPE, paddle::piano::TestDatatypes()) +// register backend which support some datetype and some op +REGISTER_PIANO_BACKEND(PIANO_JIT_TEST_FILTER, paddle::piano::TestDatatypes(), + paddle::piano::TestFilterFunc) + +namespace paddle { +namespace piano { + +class TestPianoOpMaker : public PianoOpMaker { + public: + void Make() override { AddAttr("test_attr", 100); } +}; + +// register piano op kernel with limit allow backend list +class TestPianoOpWithAllowBackendMaker : public PianoOpMaker { + public: + void Make() override { + SetAllowBackendList({"PIANO_JIT_TEST_FILTER"}); + SetDataTypes(TestDatatypes()); + AddAttr("test_attr", 100); + } +}; + +class TestPianoOpKernel : public PianoOpKernel { + public: + void Compile(const PianoOpKernelContext& context) const override { + // do nothing, pass + } +}; + +} // namespace piano +} // namespace paddle + +REGISTER_PIANO_OP(test, paddle::piano::TestPianoOpMaker, + paddle::piano::TestPianoOpKernel) +REGISTER_PIANO_OP(test_limit_backend, + paddle::piano::TestPianoOpWithAllowBackendMaker, + paddle::piano::TestPianoOpKernel) + +namespace paddle { +namespace piano { + +TEST(TestPianoOpRegistry, CheckBackendRegistered) { + ASSERT_FALSE(PianoOpRegistry::IsBackend("BACKEND_NO_REGISTERED")); + ASSERT_TRUE(PianoOpRegistry::IsBackend("PIANO_JIT_TEST_TYPE")); + ASSERT_TRUE(PianoOpRegistry::IsBackend("PIANO_JIT_TEST_FILTER")); + + auto backends = PianoOpRegistry::AllBackendNames(); + std::stable_sort(backends.begin(), backends.end()); + ASSERT_EQ(backends, std::vector( + {"PIANO_JIT_TEST_FILTER", "PIANO_JIT_TEST_TYPE"})); + ASSERT_EQ(PianoOpRegistry::BackendDataTypes("PIANO_JIT_TEST_TYPE"), + TestDatatypes()); +} + +TEST(TestPianoOpRegistry, CheckPianoOpRegistered) { + // check piano register OK + ASSERT_FALSE(PianoOpRegistry::IsPianoOp("op_no_registered")); + ASSERT_FALSE(PianoOpRegistry::IsPianoOp("op_not_piano")); + ASSERT_TRUE(PianoOpRegistry::IsPianoOp("test")); + ASSERT_TRUE(PianoOpRegistry::IsPianoOp("test_limit_backend")); + + // check register store OK + auto ops = PianoOpRegistry::AllPianoOps(); + std::stable_sort(ops.begin(), ops.end()); + ASSERT_EQ(ops, std::vector({"test", "test_limit_backend"})); + + // check piano op's attribute OK + const auto& attrs = PianoOpRegistry::Attrs("test"); + ASSERT_NE(attrs.find("test_attr"), attrs.cend()); + + const auto& attr = attrs.at("test_attr"); + ASSERT_NO_THROW(BOOST_GET_CONST(int, attr)); + ASSERT_EQ(BOOST_GET_CONST(int, attr), 100); + + // check allow backend list OK + ASSERT_FALSE(PianoOpRegistry::HasAllowBackendList("test")); + ASSERT_TRUE(PianoOpRegistry::HasAllowBackendList("test_limit_backend")); + ASSERT_EQ(PianoOpRegistry::AllowBackendList("test_limit_backend"), + std::vector({"PIANO_JIT_TEST_FILTER"})); + + ASSERT_EQ(PianoOpRegistry::PianoOpDataTypes("test_limit_backend"), + TestDatatypes()); +} + +TEST(TestPianoOpRegistry, CheckOpKernelRegistered) { + const auto& kernels = PianoOpRegistry::AllPianoOpKernels("test"); + + ASSERT_FALSE(kernels.empty()); +} + +} // namespace piano +} // namespace paddle