From 207b6c42767370588e16fe4c8e9f9e5b41cd39b2 Mon Sep 17 00:00:00 2001
From: ronnywang
Date: Wed, 19 Jul 2023 22:53:27 +0800
Subject: [PATCH] [CustomPass] add register_pass api (#55511)

---
 paddle/fluid/framework/ir/pass.cc | 15 ++++++++++++++-
 paddle/fluid/framework/ir/pass.h  |  2 ++
 paddle/fluid/pybind/pybind.cc     |  3 +++
 3 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc
index f11f52a0b1cda..f1cd1face3444 100755
--- a/paddle/fluid/framework/ir/pass.cc
+++ b/paddle/fluid/framework/ir/pass.cc
@@ -74,6 +74,16 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
     "xpu_delete_cast_op_pass",
 };
 
+static std::vector<std::string> support_subgraph_generate_passes;
+
+void Pass::AddSupportSubgraphPass(const std::string &pass_type) {
+  if (std::find(support_subgraph_generate_passes.begin(),
+                support_subgraph_generate_passes.end(),
+                pass_type) == support_subgraph_generate_passes.end()) {
+    support_subgraph_generate_passes.push_back(pass_type);
+  }
+}
+
 Graph *Pass::Apply(Graph *graph) const {
   VLOG(10) << "start to apply pass " << Type() << " to graph";
   CheckPrevPass();
@@ -117,7 +127,10 @@ Graph *Pass::Apply(Graph *graph) const {
     subgraph_passes = support_subgraph_passes;
   }
   if (graph->IsMainGraph() &&
-      std::count(subgraph_passes.begin(), subgraph_passes.end(), Type())) {
+      (std::count(subgraph_passes.begin(), subgraph_passes.end(), Type()) ||
+       std::count(support_subgraph_generate_passes.begin(),
+                  support_subgraph_generate_passes.end(),
+                  Type()))) {
     for (size_t i = 1; i < graph->SubGraphsSize(); i++) {
       auto *sub_graph = graph->GetSubGraph(i);
       if (!sub_graph->Has(framework::ir::kParamScopeAttr)) {
diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h
index 1f59466e1cd80..473890a4b786b 100644
--- a/paddle/fluid/framework/ir/pass.h
+++ b/paddle/fluid/framework/ir/pass.h
@@ -168,6 +168,8 @@ class Pass {
 
   virtual bool SupportApplyProgramViaGraph() const { return true; }
 
+  static void AddSupportSubgraphPass(const std::string &pass_type);
+
  protected:
   virtual void ApplyImpl(Graph *graph UNUSED) const {
     PADDLE_THROW(platform::errors::Unimplemented(
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index c4036944bc18a..d55cab98b1eba 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -2319,6 +2319,9 @@ All parameter, weight, gradient are variables in Paddle.
     auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
     return std::shared_ptr<framework::ir::Pass>(std::move(pass));
   });
+  m.def("register_subgraph_pass", [](const std::string &pass_type) {
+    framework::ir::Pass::AddSupportSubgraphPass(pass_type);
+  });
 
   m.def("size_of_dtype", framework::SizeOfType);
   py::class_<platform::ProfilerResult>(m, "_ProfilerResult")
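
Usage note (not part of the patch): a minimal C++ sketch of how the new static method could be called. The pass name "my_custom_generate_pass" is hypothetical and assumed to have been registered separately via the existing pass registry; only Pass::AddSupportSubgraphPass itself comes from this patch.

    #include <string>

    #include "paddle/fluid/framework/ir/pass.h"

    void MarkCustomPassForSubgraphs() {
      // "my_custom_generate_pass" is a placeholder pass type. After this call,
      // Pass::Apply will run the pass on every sub-graph of the main graph in
      // addition to the main graph itself, because the pass type is now listed
      // in support_subgraph_generate_passes.
      paddle::framework::ir::Pass::AddSupportSubgraphPass(
          "my_custom_generate_pass");
    }

The same behavior is exposed to Python through the new register_subgraph_pass binding added in pybind.cc, which simply forwards the pass type string to Pass::AddSupportSubgraphPass.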