Skip to content

Commit

Permalink
[CustomPass] add register_pass api (PaddlePaddle#55511)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 authored and cqulilujia committed Jul 24, 2023
1 parent aa85183 commit 207b6c4
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
15 changes: 14 additions & 1 deletion paddle/fluid/framework/ir/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
"xpu_delete_cast_op_pass",
};

static std::vector<std::string> support_subgraph_generate_passes;

void Pass::AddSupportSubgraphPass(const std::string &pass_type) {
if (std::find(support_subgraph_generate_passes.begin(),
support_subgraph_generate_passes.end(),
pass_type) == support_subgraph_generate_passes.end()) {
support_subgraph_generate_passes.push_back(pass_type);
}
}

Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph";
CheckPrevPass();
Expand Down Expand Up @@ -117,7 +127,10 @@ Graph *Pass::Apply(Graph *graph) const {
subgraph_passes = support_subgraph_passes;
}
if (graph->IsMainGraph() &&
std::count(subgraph_passes.begin(), subgraph_passes.end(), Type())) {
(std::count(subgraph_passes.begin(), subgraph_passes.end(), Type()) ||
std::count(support_subgraph_generate_passes.begin(),
support_subgraph_generate_passes.end(),
Type()))) {
for (size_t i = 1; i < graph->SubGraphsSize(); i++) {
auto *sub_graph = graph->GetSubGraph(i);
if (!sub_graph->Has(framework::ir::kParamScopeAttr)) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ class Pass {

virtual bool SupportApplyProgramViaGraph() const { return true; }

static void AddSupportSubgraphPass(const std::string &pass_type);

protected:
virtual void ApplyImpl(Graph *graph UNUSED) const {
PADDLE_THROW(platform::errors::Unimplemented(
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2319,6 +2319,9 @@ All parameter, weight, gradient are variables in Paddle.
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
return std::shared_ptr<framework::ir::Pass>(std::move(pass));
});
m.def("register_subgraph_pass", [](const std::string &pass_type) {
framework::ir::Pass::AddSupportSubgraphPass(pass_type);
});

m.def("size_of_dtype", framework::SizeOfType);
py::class_<paddle::platform::ProfilerResult>(m, "_ProfilerResult")
Expand Down

0 comments on commit 207b6c4

Please sign in to comment.