Skip to content

Commit

Permalink
Add piano op registry (PaddlePaddle#19)
Browse files Browse the repository at this point in the history
* draft of piano op register, version 0

* fix compile problem, now single test passed

* full single test script

* delete useless __COUNTER__ macro and move file into paddle2piano directory

* change OpDesc to Operand

* realize piano execution context getattr, and add single test script

* optimize make order and optimize single test script

* optimize Instance from pointer to reference

* decoupling op-registry and execution context, move execution context to other PR

* remove useless check macro

* add AddAllowBackendList in PianoOpMaker

* decoupling ExecutionContext, using independent PianoContext instead

* PianoOp just need one kernel, move datatype and layout information into OpRegistration

* merge main branch code and update ElementType to ElementTypeProto

* remove Makefile useless code and add final keyword for BindOp

* remove IsDerived in PianoOpMaker

* optimize op-registry class structure according to CtfGo's advice
  • Loading branch information
thisjiang authored Aug 19, 2021
1 parent c51a344 commit 1755450
Show file tree
Hide file tree
Showing 5 changed files with 579 additions and 0 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/compiler/paddle2piano/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@

cc_library(piano_compile_pass SRCS piano_compile_pass.cc DEPS pass subgraph_detector)
cc_test(piano_compile_pass_test SRCS piano_compile_pass_tester.cc DEPS piano_compile_pass)

cc_library(piano_op_registry SRCS piano_op_registry.cc DEPS framework_proto note_proto piano_data_description)
cc_test(piano_op_registry_test SRCS piano_op_registry_test.cc DEPS piano_op_registry operator op_registry)
30 changes: 30 additions & 0 deletions paddle/fluid/compiler/paddle2piano/piano_op_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

namespace paddle {
namespace piano {

class PianoOpKernelContext;

class PianoOpKernel {
public:
virtual void Compile(const PianoOpKernelContext& context) const = 0;

virtual ~PianoOpKernel() = default;
};

} // namespace piano
} // namespace paddle
103 changes: 103 additions & 0 deletions paddle/fluid/compiler/paddle2piano/piano_op_registry.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/compiler/paddle2piano/piano_op_registry.h"

#include <string>

#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace piano {

void PianoOpRegistry::RegisterBackend(
const std::string& backend_name,
const std::unordered_set<note::ElementTypeProto>& supported_types,
BackendFilterFunc filter_func) {
PADDLE_ENFORCE_EQ(
PianoOpRegistry::IsBackend(backend_name), false,
platform::errors::AlreadyExists("Backend %s has been registered.",
backend_name.c_str()));
auto& registry = Instance();
registry.backend_.emplace(backend_name, new Backend);

auto& backend = registry.backend_.at(backend_name);
backend->name = backend_name;
backend->supported_types = supported_types;
backend->filter_func = filter_func;
}

const std::unordered_set<note::ElementTypeProto>&
PianoOpRegistry::BackendDataTypes(const std::string& backend_name) {
PADDLE_ENFORCE_EQ(IsBackend(backend_name), true,
platform::errors::NotFound("Name %s not founded Backend.",
backend_name.c_str()));
return Instance().backend_.at(backend_name)->supported_types;
}

std::vector<std::string> PianoOpRegistry::AllBackendNames() {
auto& registry = Instance();
std::vector<std::string> ret;
for (const auto& backend_pair : registry.backend_) {
ret.emplace_back(backend_pair.first);
}
return ret;
}

bool PianoOpRegistry::HasAllowBackendList(const std::string& op_type) {
PADDLE_ENFORCE_EQ(
IsPianoOp(op_type), true,
platform::errors::NotFound("OP %s is not Piano Op.", op_type.c_str()));
return Instance().ops_.at(op_type)->has_allow_backend_list;
}

std::vector<std::string> PianoOpRegistry::AllPianoOps() {
auto& registry = Instance();
std::vector<std::string> ret;
for (const auto& op_pair : registry.ops_) {
ret.emplace_back(op_pair.first);
}
return ret;
}

const PianoOpRegistry::OpKernelMap& PianoOpRegistry::AllPianoOpKernels(
const std::string& op_type) {
PADDLE_ENFORCE_EQ(
IsPianoOp(op_type), true,
platform::errors::NotFound("OP %s is not Piano Op.", op_type.c_str()));

return Instance().ops_.at(op_type)->kernel_;
}

const framework::AttributeMap& PianoOpRegistry::Attrs(
const std::string& op_type) {
PADDLE_ENFORCE_EQ(
PianoOpRegistry::IsPianoOp(op_type), true,
platform::errors::NotFound("OP %s is not Piano Op.", op_type.c_str()));

return Instance().ops_.at(op_type)->attrs;
}

const std::unordered_set<note::ElementTypeProto>&
PianoOpRegistry::PianoOpDataTypes(const std::string& op_type) {
PADDLE_ENFORCE_EQ(
PianoOpRegistry::IsPianoOp(op_type), true,
platform::errors::NotFound("OP %s is not Piano Op.", op_type.c_str()));

return Instance().ops_.at(op_type)->supported_types;
}

} // namespace piano
} // namespace paddle
Loading

0 comments on commit 1755450

Please sign in to comment.