diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 531324c250c2e2..22eac537766c4d 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -9,8 +9,6 @@ add_subdirectory(testing) add_subdirectory(phi) add_subdirectory(fluid) -add_subdirectory(pass) - # NOTE(zhiqiu): The changes of cc tests # Before, (1) the source file of cc tests are distributed in different sub-directories, # (2) the tests are added and configured by calling `cc_test()` in each `CMakeLists.txt`, diff --git a/paddle/ir/CMakeLists.txt b/paddle/ir/CMakeLists.txt index 14c08067b5f356..54d93c6bd20cae 100644 --- a/paddle/ir/CMakeLists.txt +++ b/paddle/ir/CMakeLists.txt @@ -3,3 +3,4 @@ if(NOT WITH_NEWIR) endif() add_subdirectory(core) +add_subdirectory(pass) diff --git a/paddle/ir/core/type_name.h b/paddle/ir/core/type_name.h new file mode 100644 index 00000000000000..59d8c8bebb6b15 --- /dev/null +++ b/paddle/ir/core/type_name.h @@ -0,0 +1,59 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace ir { + +template +inline std::string get_type_name() { +#if defined(__clang__) || defined(__GNUC__) + std::string name = __PRETTY_FUNCTION__; + std::string key = "DesiredTypeName = "; + name = name.substr(name.find(key)); + assert(!name.empty() && "Unable to find the template parameter!"); + name = name.substr(key.size()); + assert(name.back() == "]" && "Name doesn't end in the substitution key!"); + auto sem_pos = name.find_first_of(";"); + if (sem_pos == std::string::npos) + name.pop_back(); + else + name = name.substr(0, sem_pos); + return name; +#elif defined(_MSC_VER) + std::string name = __FUNCSIG__; + std::string key = "get_type_name<"; + name = name.substr(name.find(key)); + assert(!name.empty() && "Unable to find the function name!"); + name = name.substr(key.size()); + for (std::string prefix : {"class ", "struct ", "union ", "enum "}) { + if (name.find(prefix) == 0) { + name = name.substr(prefix.size()); + break; + } + } + auto angle_pos = name.rfind('>'); + assert(angle_pos != std::string::npos && "Unable to find the closing '>'!"); + return name.substr(0, angle_pos); +#else + // No known technique for statically extracting a type name on this compiler. + // We return a string that is unlikely to look like any type in LLVM. + return "UNKNOWN_TYPE"; +#endif +} + +} // namespace ir diff --git a/paddle/pass/CMakeLists.txt b/paddle/ir/pass/CMakeLists.txt similarity index 70% rename from paddle/pass/CMakeLists.txt rename to paddle/ir/pass/CMakeLists.txt index 2a46ae60d2c886..042b3e8e1e9fc4 100644 --- a/paddle/pass/CMakeLists.txt +++ b/paddle/ir/pass/CMakeLists.txt @@ -1,7 +1,3 @@ -if(NOT WITH_NEWIR) - return() -endif() - file(GLOB NEW_PASS_SRCS "*.cc") cc_library( diff --git a/paddle/ir/pass/analysis_manager.h b/paddle/ir/pass/analysis_manager.h new file mode 100644 index 00000000000000..b83ef70e5b4a8b --- /dev/null +++ b/paddle/ir/pass/analysis_manager.h @@ -0,0 +1,308 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/ir/core/cast_utils.h" +#include "paddle/ir/core/type_id.h" +#include "paddle/ir/core/type_name.h" +#include "paddle/ir/pass/pass_instrumentation.h" +#include "paddle/ir/pass/utils.h" +#include "paddle/utils/optional.h" + +namespace ir { + +class Operation; +class AnalysisManager; +class PassInstrumentor; + +namespace detail { + +/// A utility class to reprensent the analyses that are kwnown to be preserved. +class PreservedAnalyses { + struct AllAnalysesType {}; + + public: + /// Mark all analyses as preserved. + void PreserveAll() { + preserved_ids_.insert(ir::TypeId::get()); + } + + bool IsAll() const { + return preserved_ids_.count(ir::TypeId::get()); + } + + bool IsNone() const { return preserved_ids_.empty(); } + + template + void Preserve() { + Preserve(ir::TypeId::get()); + } + + template + void Preserve() { + Preserve(); + Preserve(); + } + + void Preserve(ir::TypeId id) { preserved_ids_.insert(id); } + + template + bool IsPreserved() const { + return IsPreserved(ir::TypeId::get()); + } + + bool IsPreserved(ir::TypeId id) const { return preserved_ids_.count(id); } + + template + void Unpreserve() { + preserved_ids_.erase(ir::TypeId::get()); + } + + private: + template + friend struct AnalysisModel; + + std::unordered_set preserved_ids_; +}; + +namespace detail { + +/// Trait to check if T provides a static `IsInvalidated` method. +template +using has_is_invalidated = decltype(std::declval().IsInvalidated( + std::declval())); + +/// Implementation of `IsInvalidated` if the analysis provides a definition. +template +std::enable_if_t::value, bool> +IsInvalidated(AnalysisT& analysis, const PreservedAnalyses& pa) { // NOLINT + return analysis.IsInvalidated(pa); +} + +/// Default implementation of `IsInvalidated`. +template +std::enable_if_t::value, bool> +IsInvalidated(AnalysisT& analysis, const PreservedAnalyses& pa) { // NOLINT + return !pa.IsPreserved(); +} +} // namespace detail + +/// Abstract base class representing an analysis. +struct AnalysisConcept { + virtual ~AnalysisConcept() = default; + + // A hook used to query analyses for invalidation. + virtual bool Invalidate(PreservedAnalyses& pa) = 0; // NOLINT +}; + +template +struct AnalysisModel : public AnalysisConcept { + template + explicit AnalysisModel(Args&&... args) + : analysis(std::forward(args)...) {} + + bool Invalidate(PreservedAnalyses& pa) final { + bool result = detail::IsInvalidated(analysis, pa); + if (result) pa.Unpreserve(); + return result; + } + + AnalysisT analysis; +}; + +/// This class represents a cache of analyses for a single operation. +/// All computation, caching and invalidation of analyses takes place here. +class AnalysisMap { + public: + explicit AnalysisMap(ir::Operation* ir) : ir_(ir) {} + + template + AnalysisT& GetAnalysis(PassInstrumentor* pi, AnalysisManager& am) { + return GetAnalysisImpl(pi, ir_, am); + } + + template + std::enable_if_t< + std::is_constructible::value || + std::is_constructible::value, + AnalysisT&> + GetAnalysis(PassInstrumentor* pi, AnalysisManager& am) { // NOLINT + return GetAnalysisImpl(pi, ir::cast(ir_), am); + } + + template + paddle::optional> GetCachedAnalysis() + const { + auto res = analyses_.find(ir::TypeId::get()); + if (res == analyses_.end()) return paddle::none; + return {static_cast&>(*res->second).analysis}; + } + + ir::Operation* getOperation() const { return ir_; } + + void Clear() { analyses_.clear(); } + + /// Invalidate any cached analyses based upon the given set of preserved + void Invalidate(const PreservedAnalyses& pa) { + PreservedAnalyses pa_copy(pa); + + // Remove any analyses that were invalidaed. + // As using MapVector, order of insertion is preserved and + // dependencies always go before users, so need only one iteration. + for (auto it = analyses_.begin(); it != analyses_.end();) { + if (it->second->Invalidate(pa_copy)) + it = analyses_.erase(it); + else + ++it; + } + } + + private: + template + static std::string GetAnalysisName() { + std::string name = ir::get_type_name(); + auto pos = name.rfind("::"); + if (pos != std::string::npos) { + name = name.substr(pos + 2); + } + return name; + } + + template + AnalysisT& GetAnalysisImpl(PassInstrumentor* pi, + OpT op, + AnalysisManager& am) { // NOLINT + ir::TypeId id = ir::TypeId::get(); + auto it = analyses_.find(id); + if (it == analyses_.end()) { + if (pi) { + pi->RunBeforeAnalysis(GetAnalysisName(), id, ir_); + } + + bool was_inserted; + std::tie(it, was_inserted) = + analyses_.insert({id, ConstructAnalysis(am, op)}); + assert(was_inserted); + + if (pi) { + pi->RunAfterAnalysis(GetAnalysisName(), id, ir_); + } + } + + return static_cast&>(*it->second).analysis; + } + + /// Construct analysis using two arguments constructor (OpT, + /// AnalysisManager&). + template < + typename AnalysisT, + typename OpT, + std::enable_if_t< + std::is_constructible::value>* = + nullptr> + static auto ConstructAnalysis(AnalysisManager& am, OpT op) { // NOLINT + return std::make_unique>(op, am); + } + + /// Construct analysis using single argument constructor (OpT) + template < + typename AnalysisT, + typename OpT, + std::enable_if_t< + !std::is_constructible::value>* = + nullptr> + static auto ConstructAnalysis(AnalysisManager&, OpT op) { + return std::make_unique>(op); + } + + private: + ir::Operation* ir_; + std::unordered_map> analyses_; +}; + +} // namespace detail + +/// This class is intended to be passed around by value, and can not be +/// constructed direcyly. +class AnalysisManager { + public: + using PreservedAnalyses = detail::PreservedAnalyses; + + template + AnalysisT& GetAnalysis() { + return analyses_->GetAnalysis(GetPassInstrumentor(), *this); + } + + template + AnalysisT& GetAnalysis() { + return analyses_->GetAnalysis(GetPassInstrumentor(), *this); + } + + template + paddle::optional> GetCachedAnalysis() + const { + return analyses_->GetCachedAnalysis(); + } + + void Invalidate(const PreservedAnalyses& pa) { + if (pa.IsAll()) return; + + // Invalidate the analyses for the current operation directly. + analyses_->Invalidate(pa); + } + + void clear() { analyses_->Clear(); } + + PassInstrumentor* GetPassInstrumentor() const { return instrumentor_; } + + ir::Operation* GetOperation() { return analyses_->getOperation(); } + + private: + AnalysisManager(detail::AnalysisMap* impl, PassInstrumentor* pi) + : analyses_(impl), instrumentor_(pi) {} + + private: + detail::AnalysisMap* analyses_; + PassInstrumentor* instrumentor_; + + // For access constructor. + friend class AnalysisManagerHolder; +}; + +/// A manager class for the container operation. This class hold the +/// memory for the analyses. AnalysisManager just hold the ref to the +/// analyses. +class AnalysisManagerHolder { + public: + AnalysisManagerHolder(ir::Operation* op, PassInstrumentor* pi) + : analyses_(op), pi_(pi) {} + AnalysisManagerHolder(const AnalysisManagerHolder&) = delete; + AnalysisManagerHolder& operator=(const AnalysisManagerHolder&) = delete; + + /// Returns an analysis manager for the current container op. + operator AnalysisManager() { return AnalysisManager(&analyses_, pi_); } + + private: + detail::AnalysisMap analyses_; + PassInstrumentor* pi_; +}; + +} // namespace ir diff --git a/paddle/ir/pass/pass.cc b/paddle/ir/pass/pass.cc new file mode 100644 index 00000000000000..e4d3124a33248b --- /dev/null +++ b/paddle/ir/pass/pass.cc @@ -0,0 +1,203 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/pass/pass.h" +#include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/operation.h" +#include "paddle/ir/core/program.h" +#include "paddle/ir/pass/pass_adaptor.h" +#include "paddle/ir/pass/pass_instrumentation.h" +#include "paddle/ir/pass/pass_manager.h" + +namespace ir { + +//----------------------------------------------------------------------------------------------// +// PassAdaptor +//----------------------------------------------------------------------------------------------// +void detail::PassAdaptor::Run(ir::Operation* op, + uint8_t opt_level, + bool verify) { + RunImpl(op, opt_level, verify); +} + +void detail::PassAdaptor::RunImpl(ir::Operation* op, + uint8_t opt_level, + bool verify) { + // TODO(liuyuanle): Support block, region, etc. + return; +} + +bool detail::PassAdaptor::RunPipeline(const PassManager& pm, + ir::Operation* op, + AnalysisManager am, + uint8_t opt_level, + bool verify) { + auto* instrumentor = am.GetPassInstrumentor(); + if (instrumentor) { + instrumentor->RunBeforePipeline(op); + } + + for (auto& pass : pm.passes()) { + if (pass->CanScheduleOn(op)) { + if (!RunPass(pass.get(), op, am, opt_level, verify)) { + return false; + } + } + } + + if (instrumentor) { + instrumentor->RunAfterPipeline(op); + } + + // Apply pass manager on all nested ir. + if (!RunPass(pm.pass_adaptor_.get(), op, am, opt_level, verify)) { + return false; + } + + return true; +} + +bool detail::PassAdaptor::RunPass(Pass* pass, + ir::Operation* op, + AnalysisManager am, + uint8_t opt_level, + bool verify) { + if (opt_level < pass->pass_info().opt_level) return true; + + pass->pass_state_ = PassExecutionState(op, am); + + PassInstrumentor* instrumentor = am.GetPassInstrumentor(); + + if (auto* adaptor = dynamic_cast(pass)) { + adaptor->Run(op, opt_level, verify); + } else { + if (instrumentor) instrumentor->RunBeforePass(pass, op); + pass->Run(op); + if (instrumentor) instrumentor->RunAfterPass(pass, op); + } + + bool pass_failed = pass->pass_state().pass_failed; + + // TODO(liuyuanle): Support verification of operation + if (!pass_failed && verify) { + // bool verify_recursively = !dynamic_cast(pass); + // pass_failed = ir::verify(op, verify_recursively); + } + + return !pass_failed; +} + +//----------------------------------------------------------------------------------------------// +// PassManager +//----------------------------------------------------------------------------------------------// +PassManager::PassManager(ir::IrContext* context, uint8_t opt_level) + : context_(context), opt_level_(opt_level) { + pass_adaptor_ = std::make_unique(this); +} + +// bool PassManager::Run(ir::Program* program) const { +// if (!Initialize(context_)) { +// return false; +// } +// return Run(program->operation()); +// } + +bool PassManager::Run(ir::Operation* op) const { + // Construct a analysis manager for the pipeline. + AnalysisManagerHolder am(op, instrumentor_.get()); + + return detail::PassAdaptor::RunPipeline(*this, op, am, opt_level_, verify_); +} + +bool PassManager::Initialize(ir::IrContext* context) const { + for (auto& pass : passes()) { + if (!pass->Initialize(context)) return false; + } + + return true; +} + +void PassManager::AddInstrumentation(std::unique_ptr pi) { + if (!instrumentor_) instrumentor_ = std::make_unique(); + + instrumentor_->AddInstrumentation(std::move(pi)); +} + +//----------------------------------------------------------------------------------------------// +// PassInstrumentor +//----------------------------------------------------------------------------------------------// +namespace detail { +struct PassInstrumentorImpl { + // TODO(wilber): Support multi-thread. + std::vector> instrumentations; +}; +} // namespace detail + +PassInstrumentor::PassInstrumentor() + : impl_(new detail::PassInstrumentorImpl{}) {} + +PassInstrumentor::~PassInstrumentor() = default; + +void PassInstrumentor::RunBeforePipeline(ir::Operation* op) { + for (auto& instr : impl_->instrumentations) { + instr->RunBeforePipeline(op); + } +} + +void PassInstrumentor::RunAfterPipeline(ir::Operation* op) { + for (auto it = impl_->instrumentations.rbegin(); + it != impl_->instrumentations.rend(); + ++it) { + (*it)->RunAfterPipeline(op); + } +} + +void PassInstrumentor::RunBeforePass(Pass* pass, ir::Operation* op) { + for (auto& instr : impl_->instrumentations) { + instr->RunBeforePass(pass, op); + } +} + +void PassInstrumentor::RunAfterPass(Pass* pass, ir::Operation* op) { + for (auto it = impl_->instrumentations.rbegin(); + it != impl_->instrumentations.rend(); + ++it) { + (*it)->RunAfterPass(pass, op); + } +} + +void PassInstrumentor::RunBeforeAnalysis(const std::string& name, + ir::TypeId id, + ir::Operation* op) { + for (auto& instr : impl_->instrumentations) { + instr->RunBeforeAnalysis(name, id, op); + } +} + +void PassInstrumentor::RunAfterAnalysis(const std::string& name, + ir::TypeId id, + ir::Operation* op) { + for (auto it = impl_->instrumentations.rbegin(); + it != impl_->instrumentations.rend(); + ++it) { + (*it)->RunBeforeAnalysis(name, id, op); + } +} + +void PassInstrumentor::AddInstrumentation( + std::unique_ptr pi) { + impl_->instrumentations.emplace_back(std::move(pi)); +} + +} // namespace ir diff --git a/paddle/pass/pass.h b/paddle/ir/pass/pass.h similarity index 71% rename from paddle/pass/pass.h rename to paddle/ir/pass/pass.h index bcbabe6b36efb8..504c081111ac8a 100644 --- a/paddle/pass/pass.h +++ b/paddle/ir/pass/pass.h @@ -17,6 +17,8 @@ #include #include +#include "paddle/ir/pass/analysis_manager.h" +#include "paddle/phi/core/enforce.h" #include "paddle/utils/optional.h" namespace ir { @@ -31,12 +33,15 @@ class PassAdaptor; namespace detail { struct PassExecutionState { - explicit PassExecutionState(ir::Operation* ir) : ir(ir), pass_failed(false) {} + explicit PassExecutionState(ir::Operation* ir, const AnalysisManager& am) + : ir(ir), pass_failed(false), am(am) {} + // The IR currently being processed by pass. ir::Operation* ir; + bool pass_failed; - // TODO(liuyuanle): Add implementation of AnalysisManager and - // PreservedAnalyses. + AnalysisManager am; + PreservedAnalyses preserved_analyses; }; struct PassInfo { @@ -51,7 +56,7 @@ struct PassInfo { // opt_level=0: the basic pass which framework need. // opt_level=1: the fusion logical pass. // opt_level=2: constant fold, cse, memory optimize, etc. - // opt_level=3: layout. + // opt_level=3: layout, etc. uint8_t opt_level; // The list which pass depends on. @@ -67,11 +72,11 @@ class Pass { explicit Pass(const char* name, uint8_t opt_level, const std::vector& dependents = {}) - : info_(name, opt_level, dependents) {} + : pass_info_(name, opt_level, dependents) {} virtual ~Pass() = default; - const detail::PassInfo& GetPassInfo() const { return info_; } + const detail::PassInfo& pass_info() const { return pass_info_; } protected: virtual void Run(ir::Operation* op) = 0; @@ -81,9 +86,19 @@ class Pass { virtual bool Initialize(ir::IrContext* context) { return true; } - void SignalPassFailure() { pass_state_->pass_failed = true; } + AnalysisManager analysis_manager() { return pass_state().am; } + + detail::PassExecutionState& pass_state() { + PADDLE_ENFORCE_EQ(pass_state_.is_initialized(), + true, + phi::errors::Fatal("pass state was never initialized")); + return *pass_state_; + } + + void SignalPassFailure() { pass_state().pass_failed = true; } - detail::PassInfo info_; + private: + detail::PassInfo pass_info_; paddle::optional pass_state_; diff --git a/paddle/pass/pass_adaptor.h b/paddle/ir/pass/pass_adaptor.h similarity index 71% rename from paddle/pass/pass_adaptor.h rename to paddle/ir/pass/pass_adaptor.h index 2bc82510617aae..a580f9435febd4 100644 --- a/paddle/pass/pass_adaptor.h +++ b/paddle/ir/pass/pass_adaptor.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/pass/pass.h" +#include "paddle/ir/pass/pass.h" namespace ir { @@ -30,16 +30,22 @@ class PassAdaptor final : public Pass { void Run(ir::Operation*) override {} - void Run(ir::Operation*, uint8_t opt_level); + void Run(ir::Operation*, uint8_t opt_level, bool verify); private: - void RunImpl(ir::Operation* op, uint8_t opt_level); + void RunImpl(ir::Operation* op, uint8_t opt_level, bool verify); - static bool RunPass(Pass* pass, ir::Operation* op, uint8_t opt_level); + static bool RunPass(Pass* pass, + ir::Operation* op, + AnalysisManager am, + uint8_t opt_level, + bool verify); static bool RunPipeline(const PassManager& pm, ir::Operation* op, - uint8_t opt_level); + AnalysisManager am, + uint8_t opt_level, + bool verify); // Use for RunImpl later. PassManager* pm_; diff --git a/paddle/ir/pass/pass_instrumentation.h b/paddle/ir/pass/pass_instrumentation.h new file mode 100644 index 00000000000000..d9e10eb6c58c76 --- /dev/null +++ b/paddle/ir/pass/pass_instrumentation.h @@ -0,0 +1,86 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/ir/core/type_id.h" + +namespace ir { + +class Operation; +class Pass; + +namespace detail { +struct PassInstrumentorImpl; +} // namespace detail + +class PassInstrumentation { + public: + PassInstrumentation() = default; + virtual ~PassInstrumentation() = default; + + /// A callback to run before a pass pipeline is executed. + virtual void RunBeforePipeline(ir::Operation* op) {} + + virtual void RunAfterPipeline(ir::Operation* op) {} + + virtual void RunBeforePass(Pass* pass, ir::Operation* op) {} + + virtual void RunAfterPass(Pass* pass, ir::Operation* op) {} + + virtual void RunBeforeAnalysis(const std::string& name, + ir::TypeId id, + ir::Operation* op) {} + + virtual void RunAfterAnalysis(const std::string& name, + ir::TypeId id, + ir::Operation* op) {} +}; + +/// This class holds a collection of PassInstrumentation obejcts, and invokes +/// their respective callbacks. +class PassInstrumentor { + public: + PassInstrumentor(); + ~PassInstrumentor(); + PassInstrumentor(PassInstrumentor&&) = delete; + PassInstrumentor(const PassInstrumentor&) = delete; + + void AddInstrumentation(std::unique_ptr pi); + + void RunBeforePipeline(ir::Operation* op); + + void RunAfterPipeline(ir::Operation* op); + + void RunBeforePass(Pass* pass, ir::Operation* op); + + void RunAfterPass(Pass* pass, ir::Operation* op); + + void RunBeforeAnalysis(const std::string& name, + ir::TypeId id /* */, + ir::Operation* op); + + void RunAfterAnalysis(const std::string& name, + ir::TypeId id, + ir::Operation* op); + + // TODO(wilber): Add other hooks. + + private: + std::unique_ptr impl_; +}; + +} // namespace ir diff --git a/paddle/pass/pass_manager.h b/paddle/ir/pass/pass_manager.h similarity index 68% rename from paddle/pass/pass_manager.h rename to paddle/ir/pass/pass_manager.h index 3969c65d264510..0d7706dee37d30 100644 --- a/paddle/pass/pass_manager.h +++ b/paddle/ir/pass/pass_manager.h @@ -14,14 +14,20 @@ #pragma once +#include #include #include +#include "paddle/ir/core/program.h" + namespace ir { class IrContext; class Operation; +class Program; class Pass; +class PassInstrumentation; +class PassInstrumentor; namespace detail { class PassAdaptor; @@ -33,34 +39,37 @@ class PassManager { ~PassManager() = default; - const std::vector> &GetPasses() const { - return passes_; - } + const std::vector> &passes() const { return passes_; } - bool Empty() const { return passes_.empty(); } + bool empty() const { return passes_.empty(); } - ir::IrContext *GetContext() const { return context_; } + ir::IrContext *context() const { return context_; } - bool Run(ir::Operation *op); + // bool Run(ir::Program *program) const; + bool Run(ir::Operation *op) const; void AddPass(std::unique_ptr pass) { passes_.emplace_back(std::move(pass)); } - private: - bool RunPasses(ir::Operation *op); + void AddInstrumentation(std::unique_ptr pi); - bool Initialize(ir::IrContext *context); + private: + bool Initialize(ir::IrContext *context) const; private: ir::IrContext *context_; uint8_t opt_level_; + bool verify_{true}; + std::vector> passes_; std::unique_ptr pass_adaptor_; + std::unique_ptr instrumentor_; + friend class detail::PassAdaptor; }; diff --git a/paddle/ir/pass/utils.h b/paddle/ir/pass/utils.h new file mode 100644 index 00000000000000..b3724431d11e91 --- /dev/null +++ b/paddle/ir/pass/utils.h @@ -0,0 +1,45 @@ +// paddle/pass/utils.h + +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace ir { +namespace detail { + +template +struct make_void { + typedef void type; +}; + +template +using void_t = typename make_void::type; +template class Op, class... Args> +struct detector { + using value_t = std::false_type; +}; + +template