Skip to content

Commit

Permalink
[New-IR] add pass registry (#56729)
Browse files Browse the repository at this point in the history
* add pass registry

* add pass registry macro
  • Loading branch information
zhiqiu authored Aug 29, 2023
1 parent fc1e1b7 commit 9999e84
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 11 deletions.
15 changes: 5 additions & 10 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "paddle/ir/core/value.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/pass_registry.h"
#include "paddle/ir/transforms/dead_code_elimination_pass.h"
#include "paddle/phi/core/enforce.h"
#include "pybind11/stl.h"
Expand All @@ -57,6 +58,8 @@ using paddle::dialect::APIBuilder;
using paddle::dialect::DenseTensorType;
using pybind11::return_value_policy;

USE_PASS(dead_code_elimination);

namespace paddle {
namespace pybind {

Expand Down Expand Up @@ -488,15 +491,6 @@ void BindIrPass(pybind11::module *m) {
[](const Pass &self) { return self.pass_info().dependents; });
}

// TODO(zhiqiu): refine pass registry
std::unique_ptr<Pass> CreatePassByName(std::string name) {
if (name == "DeadCodeEliminationPass") {
return ir::CreateDeadCodeEliminationPass();
} else {
IR_THROW("The %s pass is not registed", name);
}
}

void BindPassManager(pybind11::module *m) {
py::class_<PassManager, std::shared_ptr<PassManager>> pass_manager(
*m,
Expand All @@ -514,7 +508,8 @@ void BindPassManager(pybind11::module *m) {
py::arg("opt_level") = 2)
.def("add_pass",
[](PassManager &self, std::string pass_name) {
self.AddPass(std::move(CreatePassByName(pass_name)));
self.AddPass(
std::move(ir::PassRegistry::Instance().Get(pass_name)));
})
.def("passes",
[](PassManager &self) {
Expand Down
1 change: 1 addition & 0 deletions paddle/ir/pass/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "paddle/ir/core/enforce.h"
#include "paddle/ir/pass/analysis_manager.h"
#include "paddle/ir/pass/pass_registry.h"
#include "paddle/phi/core/enforce.h"

namespace ir {
Expand Down
23 changes: 23 additions & 0 deletions paddle/ir/pass/pass_registry.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/pass/pass_registry.h"

namespace ir {
PassRegistry &PassRegistry::Instance() {
static PassRegistry g_pass_info_map;
return g_pass_info_map;
}

} // namespace ir
104 changes: 104 additions & 0 deletions paddle/ir/pass/pass_registry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <memory>
#include <unordered_map>

#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/macros.h"
#include "paddle/ir/pass/pass.h"

namespace ir {

class Pass;

using PassCreator = std::function<std::unique_ptr<Pass>()>;

class PassRegistry {
public:
static PassRegistry &Instance();

bool Has(const std::string &pass_type) const {
return pass_map_.find(pass_type) != pass_map_.end();
}

void Insert(const std::string &pass_type, const PassCreator &pass_creator) {
IR_ENFORCE(
Has(pass_type) != true, "Pass %s has been registered.", pass_type);
pass_map_.insert({pass_type, pass_creator});
}

std::unique_ptr<Pass> Get(const std::string &pass_type) const {
IR_ENFORCE(
Has(pass_type) == true, "Pass %s has not been registered.", pass_type);
return pass_map_.at(pass_type)();
}

private:
PassRegistry() = default;
std::unordered_map<std::string, PassCreator> pass_map_;

DISABLE_COPY_AND_ASSIGN(PassRegistry);
};

template <typename PassType>
class PassRegistrar {
public:
// In our design, various kinds of passes,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which
// are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_PASS macros to
// call this method. So, as long as the callee code calls USE_PASS, the global
// registrar variable won't be removed by the linker.
void Touch() {}
explicit PassRegistrar(const char *pass_type) {
PassRegistry::Instance().Insert(
pass_type, []() { return std::make_unique<PassType>(); });
}
};

#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)

// Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::ir::PassRegistrar<pass_class> __pass_registrar_##pass_type##__( \
#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
} \
static ::ir::PassRegistrar<pass_class> &__pass_tmp_registrar_##pass_type##__ \
UNUSED = __pass_registrar_##pass_type##__

#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ UNUSED = \
TouchPassRegistrar_##pass_type()

} // namespace ir
3 changes: 3 additions & 0 deletions paddle/ir/transforms/dead_code_elimination_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_registry.h"

namespace {

Expand Down Expand Up @@ -75,3 +76,5 @@ std::unique_ptr<Pass> CreateDeadCodeEliminationPass() {
}

} // namespace ir

REGISTER_PASS(dead_code_elimination, DeadCodeEliminationPass);
3 changes: 2 additions & 1 deletion test/ir/new_ir/test_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ def test_op(self):
self.assertTrue('pd.uniform' in op_names)
pm = ir.PassManager()
pm.add_pass(
'DeadCodeEliminationPass'
'dead_code_elimination'
) # apply pass to elimitate dead code
pm.run(new_program)
op_names = [op.name() for op in new_program.block().ops]
# print(op_names)
# TODO(zhiqiu): unify the name of pass
self.assertEqual(pm.passes(), ['DeadCodeEliminationPass'])
self.assertFalse(pm.empty())
self.assertTrue(
Expand Down

0 comments on commit 9999e84

Please sign in to comment.