Skip to content

Commit

Permalink
[IR&PASS] part 2-3: add PassTiming (#54348)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome authored Jun 6, 2023
1 parent 3ea7d57 commit 24c4e37
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 35 deletions.
18 changes: 10 additions & 8 deletions paddle/ir/pass/ir_printing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,24 @@
// limitations under the License.

#include <ostream>
#include <string>
#include <unordered_map>

#include "paddle/ir/core/operation.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/utils.h"

namespace ir {

namespace {
void PrintIR(Operation *op, bool print_module, std::ostream &os) {
// Otherwise, check to see if we are not printing at module scope.
if (print_module) {
if (!print_module) {
op->Print(os << "\n");
return;
}

// Otherwise, we are printing at module scope.
os << " ('" << op->name() << "' operation)\n";

// Find the top-level operation.
auto *top_op = op;
while (auto *parent_op = top_op->GetParentOp()) {
Expand All @@ -55,7 +53,9 @@ class IRPrinting : public PassInstrumentation {
}

option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) {
os << "// *** IR Dump Before " << pass->pass_info().name << " ***";
std::string header =
"IRPrinting on " + op->name() + " before " + pass->name() + " pass";
detail::PrintHeader(header, os);
PrintIR(op, option_->EnablePrintModule(), os);
os << "\n\n";
});
Expand All @@ -66,8 +66,10 @@ class IRPrinting : public PassInstrumentation {
// TODO(liuyuanle): support print on change
}

option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) {
os << "// *** IR Dump After " << pass->pass_info().name << " ***";
option_->PrintAfterIfEnabled(pass, op, [&](std::ostream &os) {
std::string header =
"IRPrinting on " + op->name() + " after " + pass->name() + " pass";
detail::PrintHeader(header, os);
PrintIR(op, option_->EnablePrintModule(), os);
os << "\n\n";
});
Expand Down
40 changes: 22 additions & 18 deletions paddle/ir/pass/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@ bool Pass::CanApplyOn(Operation* op) const { return op->num_regions() > 0; }
//----------------------------------------------------------------------------------------------//
// PassAdaptor
//----------------------------------------------------------------------------------------------//
void detail::PassAdaptor::Run(ir::Operation* op,
uint8_t opt_level,
bool verify) {
void detail::PassAdaptor::Run(Operation* op, uint8_t opt_level, bool verify) {
RunImpl(op, opt_level, verify);
}

void detail::PassAdaptor::RunImpl(ir::Operation* op,
void detail::PassAdaptor::RunImpl(Operation* op,
uint8_t opt_level,
bool verify) {
auto last_am = analysis_manager();
Expand All @@ -60,7 +58,7 @@ void detail::PassAdaptor::RunImpl(ir::Operation* op,
}

bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
ir::Operation* op,
Operation* op,
AnalysisManager am,
uint8_t opt_level,
bool verify) {
Expand Down Expand Up @@ -90,7 +88,7 @@ bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
}

bool detail::PassAdaptor::RunPass(Pass* pass,
ir::Operation* op,
Operation* op,
AnalysisManager am,
uint8_t opt_level,
bool verify) {
Expand Down Expand Up @@ -122,26 +120,26 @@ bool detail::PassAdaptor::RunPass(Pass* pass,
//----------------------------------------------------------------------------------------------//
// PassManager
//----------------------------------------------------------------------------------------------//
PassManager::PassManager(ir::IrContext* context, uint8_t opt_level)
PassManager::PassManager(IrContext* context, uint8_t opt_level)
: context_(context), opt_level_(opt_level) {
pass_adaptor_ = std::make_unique<detail::PassAdaptor>(this);
}

bool PassManager::Run(ir::Program* program) {
bool PassManager::Run(Program* program) {
if (!Initialize(context_)) {
return false;
}
return Run(program->module_op());
}

bool PassManager::Run(ir::Operation* op) {
bool PassManager::Run(Operation* op) {
// Construct a analysis manager for the pipeline.
AnalysisManagerHolder am(op, instrumentor_.get());

return detail::PassAdaptor::RunPipeline(*this, op, am, opt_level_, verify_);
}

bool PassManager::Initialize(ir::IrContext* context) {
bool PassManager::Initialize(IrContext* context) {
for (auto& pass : passes()) {
if (!pass->Initialize(context)) return false;
}
Expand Down Expand Up @@ -170,27 +168,31 @@ PassInstrumentor::PassInstrumentor()

PassInstrumentor::~PassInstrumentor() = default;

void PassInstrumentor::RunBeforePipeline(ir::Operation* op) {
void PassInstrumentor::RunBeforePipeline(Operation* op) {
if (op->num_regions() == 0) return;
for (auto& instr : impl_->instrumentations) {
instr->RunBeforePipeline(op);
}
}

void PassInstrumentor::RunAfterPipeline(ir::Operation* op) {
void PassInstrumentor::RunAfterPipeline(Operation* op) {
if (op->num_regions() == 0) return;
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
(*it)->RunAfterPipeline(op);
}
}

void PassInstrumentor::RunBeforePass(Pass* pass, ir::Operation* op) {
void PassInstrumentor::RunBeforePass(Pass* pass, Operation* op) {
if (op->num_regions() == 0) return;
for (auto& instr : impl_->instrumentations) {
instr->RunBeforePass(pass, op);
}
}

void PassInstrumentor::RunAfterPass(Pass* pass, ir::Operation* op) {
void PassInstrumentor::RunAfterPass(Pass* pass, Operation* op) {
if (op->num_regions() == 0) return;
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
Expand All @@ -199,16 +201,18 @@ void PassInstrumentor::RunAfterPass(Pass* pass, ir::Operation* op) {
}

void PassInstrumentor::RunBeforeAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op) {
TypeId id,
Operation* op) {
if (op->num_regions() == 0) return;
for (auto& instr : impl_->instrumentations) {
instr->RunBeforeAnalysis(name, id, op);
}
}

void PassInstrumentor::RunAfterAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op) {
TypeId id,
Operation* op) {
if (op->num_regions() == 0) return;
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pass/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class Pass {

virtual ~Pass();

std::string name() { return pass_info().name; }
std::string name() const { return pass_info().name; }

const detail::PassInfo& pass_info() const { return pass_info_; }

Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pass/pass_adaptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PassAdaptor final : public Pass {
uint8_t opt_level,
bool verify);

// Use for RunImpl later.
private:
PassManager* pm_;

// For accessing RunPipeline.
Expand Down
7 changes: 6 additions & 1 deletion paddle/ir/pass/pass_instrumentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,24 @@ class PassInstrumentation {
PassInstrumentation() = default;
virtual ~PassInstrumentation() = default;

/// A callback to run before a pass pipeline is executed.
// A callback to run before a pass pipeline is executed.
virtual void RunBeforePipeline(Operation* op) {}

// A callback to run after a pass pipeline is executed.
virtual void RunAfterPipeline(Operation* op) {}

// A callback to run before a pass is executed.
virtual void RunBeforePass(Pass* pass, Operation* op) {}

// A callback to run after a pass is executed.
virtual void RunAfterPass(Pass* pass, Operation* op) {}

// A callback to run before a analysis is executed.
virtual void RunBeforeAnalysis(const std::string& name,
TypeId id,
Operation* op) {}

// A callback to run after a analysis is executed.
virtual void RunAfterAnalysis(const std::string& name,
TypeId id,
Operation* op) {}
Expand Down
4 changes: 3 additions & 1 deletion paddle/ir/pass/pass_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class PassManager {

const std::vector<std::unique_ptr<Pass>> &passes() const { return passes_; }

bool empty() const { return passes_.empty(); }
bool Empty() const { return passes_.empty(); }

IrContext *context() const { return context_; }

Expand Down Expand Up @@ -111,6 +111,8 @@ class PassManager {

void EnableIRPrinting(std::unique_ptr<IRPrinterOption> config);

void EnablePassTiming(bool print_module = true);

void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi);

private:
Expand Down
123 changes: 123 additions & 0 deletions paddle/ir/pass/pass_timing.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <chrono>
#include <iomanip>
#include <ostream>
#include <string>
#include <unordered_map>

#include "paddle/ir/core/operation.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/utils.h"
namespace ir {
namespace {
class Timer {
public:
Timer() = default;

~Timer() = default;

void Start() { start_time_ = std::chrono::steady_clock::now(); }

void Stop() { walk_time += std::chrono::steady_clock::now() - start_time_; }

double GetTimePerSecond() const {
return std::chrono::duration_cast<std::chrono::duration<double>>(walk_time)
.count();
}

private:
std::chrono::time_point<std::chrono::steady_clock> start_time_;

std::chrono::nanoseconds walk_time = std::chrono::nanoseconds(0);
};
} // namespace

class PassTimer : public PassInstrumentation {
public:
explicit PassTimer(bool print_module) : print_module_(print_module) {}
~PassTimer() = default;

void RunBeforePipeline(ir::Operation* op) override {
pipeline_timers_[op] = Timer();
pipeline_timers_[op].Start();
}

void RunAfterPipeline(Operation* op) override {
pipeline_timers_[op].Stop();
PrintTime(op, std::cout);
}

void RunBeforePass(Pass* pass, Operation* op) override {
if (!pass_timers_.count(op)) {
pass_timers_[op] = {};
}
pass_timers_[op][pass->name()] = Timer();
pass_timers_[op][pass->name()].Start();
}

void RunAfterPass(Pass* pass, Operation* op) override {
pass_timers_[op][pass->name()].Stop();
}

private:
void PrintTime(Operation* op, std::ostream& os) {
if (print_module_ && op->name() != "builtin.module") return;

std::string header = "PassTiming on " + op->name();
detail::PrintHeader(header, os);

os << " Total Execution Time: " << std::fixed << std::setprecision(3)
<< pipeline_timers_[op].GetTimePerSecond() << " seconds\n\n";
os << " ----Walk Time---- ----Name----\n";

auto& map = pass_timers_[op];
std::vector<std::pair<std::string, Timer>> pairs(map.begin(), map.end());
std::sort(pairs.begin(),
pairs.end(),
[](const std::pair<std::string, Timer>& lhs,
const std::pair<std::string, Timer>& rhs) {
return lhs.second.GetTimePerSecond() >
rhs.second.GetTimePerSecond();
});

for (auto& v : pairs) {
os << " " << std::fixed << std::setw(8) << std::setprecision(3)
<< v.second.GetTimePerSecond() << " (" << std::setw(5)
<< std::setprecision(1)
<< 100 * v.second.GetTimePerSecond() /
pipeline_timers_[op].GetTimePerSecond()
<< "%)"
<< " " << v.first << "\n";
}
}

private:
bool print_module_;

std::unordered_map<Operation*, Timer> pipeline_timers_;

std::unordered_map<Operation*,
std::unordered_map<std::string /*pass name*/, Timer>>
pass_timers_;
};

void PassManager::EnablePassTiming(bool print_module) {
AddInstrumentation(std::make_unique<PassTimer>(print_module));
}

} // namespace ir
28 changes: 28 additions & 0 deletions paddle/ir/pass/utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/pass/utils.h"

namespace ir {
namespace detail {

void PrintHeader(const std::string &header, std::ostream &os) {
unsigned padding = (80 - header.size()) / 2;
os << "===" << std::string(73, '-') << "===\n";
os << std::string(padding, ' ') << header << "\n";
os << "===" << std::string(73, '-') << "===\n";
}

} // namespace detail
} // namespace ir
Loading

0 comments on commit 24c4e37

Please sign in to comment.