Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR&PASS] part 2-3: add PassTiming #54348

Merged
merged 3 commits into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions paddle/ir/pass/ir_printing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,24 @@
// limitations under the License.

#include <ostream>
#include <string>
#include <unordered_map>

#include "paddle/ir/core/operation.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/utils.h"

namespace ir {

namespace {
void PrintIR(Operation *op, bool print_module, std::ostream &os) {
// Otherwise, check to see if we are not printing at module scope.
if (print_module) {
if (!print_module) {
op->Print(os << "\n");
return;
}

// Otherwise, we are printing at module scope.
os << " ('" << op->name() << "' operation)\n";

// Find the top-level operation.
auto *top_op = op;
while (auto *parent_op = top_op->GetParentOp()) {
Expand All @@ -55,7 +53,9 @@ class IRPrinting : public PassInstrumentation {
}

option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) {
os << "// *** IR Dump Before " << pass->pass_info().name << " ***";
std::string header =
"IRPrinting on " + op->name() + " before " + pass->name() + " pass";
detail::PrintHeader(header, os);
PrintIR(op, option_->EnablePrintModule(), os);
os << "\n\n";
});
Expand All @@ -66,8 +66,10 @@ class IRPrinting : public PassInstrumentation {
// TODO(liuyuanle): support print on change
}

option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) {
os << "// *** IR Dump After " << pass->pass_info().name << " ***";
option_->PrintAfterIfEnabled(pass, op, [&](std::ostream &os) {
std::string header =
"IRPrinting on " + op->name() + " after " + pass->name() + " pass";
detail::PrintHeader(header, os);
PrintIR(op, option_->EnablePrintModule(), os);
os << "\n\n";
});
Expand Down
40 changes: 22 additions & 18 deletions paddle/ir/pass/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@ bool Pass::CanApplyOn(Operation* op) const { return op->num_regions() > 0; }
//----------------------------------------------------------------------------------------------//
// PassAdaptor
//----------------------------------------------------------------------------------------------//
void detail::PassAdaptor::Run(ir::Operation* op,
uint8_t opt_level,
bool verify) {
void detail::PassAdaptor::Run(Operation* op, uint8_t opt_level, bool verify) {
RunImpl(op, opt_level, verify);
}

void detail::PassAdaptor::RunImpl(ir::Operation* op,
void detail::PassAdaptor::RunImpl(Operation* op,
uint8_t opt_level,
bool verify) {
auto last_am = analysis_manager();
Expand All @@ -60,7 +58,7 @@ void detail::PassAdaptor::RunImpl(ir::Operation* op,
}

bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
ir::Operation* op,
Operation* op,
AnalysisManager am,
uint8_t opt_level,
bool verify) {
Expand Down Expand Up @@ -90,7 +88,7 @@ bool detail::PassAdaptor::RunPipeline(const PassManager& pm,
}

bool detail::PassAdaptor::RunPass(Pass* pass,
ir::Operation* op,
Operation* op,
AnalysisManager am,
uint8_t opt_level,
bool verify) {
Expand Down Expand Up @@ -122,26 +120,26 @@ bool detail::PassAdaptor::RunPass(Pass* pass,
//----------------------------------------------------------------------------------------------//
// PassManager
//----------------------------------------------------------------------------------------------//
PassManager::PassManager(ir::IrContext* context, uint8_t opt_level)
PassManager::PassManager(IrContext* context, uint8_t opt_level)
: context_(context), opt_level_(opt_level) {
pass_adaptor_ = std::make_unique<detail::PassAdaptor>(this);
}

bool PassManager::Run(ir::Program* program) {
bool PassManager::Run(Program* program) {
if (!Initialize(context_)) {
return false;
}
return Run(program->module_op());
}

bool PassManager::Run(ir::Operation* op) {
bool PassManager::Run(Operation* op) {
// Construct a analysis manager for the pipeline.
AnalysisManagerHolder am(op, instrumentor_.get());

return detail::PassAdaptor::RunPipeline(*this, op, am, opt_level_, verify_);
}

bool PassManager::Initialize(ir::IrContext* context) {
bool PassManager::Initialize(IrContext* context) {
for (auto& pass : passes()) {
if (!pass->Initialize(context)) return false;
}
Expand Down Expand Up @@ -170,27 +168,31 @@ PassInstrumentor::PassInstrumentor()

PassInstrumentor::~PassInstrumentor() = default;

void PassInstrumentor::RunBeforePipeline(ir::Operation* op) {
void PassInstrumentor::RunBeforePipeline(Operation* op) {
if (op->num_regions() == 0) return;
for (auto& instr : impl_->instrumentations) {
instr->RunBeforePipeline(op);
}
}

void PassInstrumentor::RunAfterPipeline(ir::Operation* op) {
void PassInstrumentor::RunAfterPipeline(Operation* op) {
if (op->num_regions() == 0) return;
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
(*it)->RunAfterPipeline(op);
}
}

void PassInstrumentor::RunBeforePass(Pass* pass, ir::Operation* op) {
void PassInstrumentor::RunBeforePass(Pass* pass, Operation* op) {
if (op->num_regions() == 0) return;
for (auto& instr : impl_->instrumentations) {
instr->RunBeforePass(pass, op);
}
}

void PassInstrumentor::RunAfterPass(Pass* pass, ir::Operation* op) {
void PassInstrumentor::RunAfterPass(Pass* pass, Operation* op) {
if (op->num_regions() == 0) return;
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
Expand All @@ -199,16 +201,18 @@ void PassInstrumentor::RunAfterPass(Pass* pass, ir::Operation* op) {
}

void PassInstrumentor::RunBeforeAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op) {
TypeId id,
Operation* op) {
if (op->num_regions() == 0) return;
for (auto& instr : impl_->instrumentations) {
instr->RunBeforeAnalysis(name, id, op);
}
}

void PassInstrumentor::RunAfterAnalysis(const std::string& name,
ir::TypeId id,
ir::Operation* op) {
TypeId id,
Operation* op) {
if (op->num_regions() == 0) return;
for (auto it = impl_->instrumentations.rbegin();
it != impl_->instrumentations.rend();
++it) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pass/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class Pass {

virtual ~Pass();

std::string name() { return pass_info().name; }
std::string name() const { return pass_info().name; }

const detail::PassInfo& pass_info() const { return pass_info_; }

Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pass/pass_adaptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PassAdaptor final : public Pass {
uint8_t opt_level,
bool verify);

// Use for RunImpl later.
private:
PassManager* pm_;

// For accessing RunPipeline.
Expand Down
7 changes: 6 additions & 1 deletion paddle/ir/pass/pass_instrumentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,24 @@ class PassInstrumentation {
PassInstrumentation() = default;
virtual ~PassInstrumentation() = default;

/// A callback to run before a pass pipeline is executed.
// A callback to run before a pass pipeline is executed.
virtual void RunBeforePipeline(Operation* op) {}

// A callback to run after a pass pipeline is executed.
virtual void RunAfterPipeline(Operation* op) {}

// A callback to run before a pass is executed.
virtual void RunBeforePass(Pass* pass, Operation* op) {}

// A callback to run after a pass is executed.
virtual void RunAfterPass(Pass* pass, Operation* op) {}

// A callback to run before a analysis is executed.
virtual void RunBeforeAnalysis(const std::string& name,
TypeId id,
Operation* op) {}

// A callback to run after a analysis is executed.
virtual void RunAfterAnalysis(const std::string& name,
TypeId id,
Operation* op) {}
Expand Down
4 changes: 3 additions & 1 deletion paddle/ir/pass/pass_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class PassManager {

const std::vector<std::unique_ptr<Pass>> &passes() const { return passes_; }

bool empty() const { return passes_.empty(); }
bool Empty() const { return passes_.empty(); }

IrContext *context() const { return context_; }

Expand Down Expand Up @@ -111,6 +111,8 @@ class PassManager {

void EnableIRPrinting(std::unique_ptr<IRPrinterOption> config);

void EnablePassTiming(bool print_module = true);

void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi);

private:
Expand Down
123 changes: 123 additions & 0 deletions paddle/ir/pass/pass_timing.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <chrono>
#include <iomanip>
#include <ostream>
#include <string>
#include <unordered_map>

#include "paddle/ir/core/operation.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/utils.h"
namespace ir {
namespace {
class Timer {
public:
Timer() = default;

~Timer() = default;

void Start() { start_time_ = std::chrono::steady_clock::now(); }

void Stop() { walk_time += std::chrono::steady_clock::now() - start_time_; }

double GetTimePerSecond() const {
return std::chrono::duration_cast<std::chrono::duration<double>>(walk_time)
.count();
}

private:
std::chrono::time_point<std::chrono::steady_clock> start_time_;

std::chrono::nanoseconds walk_time = std::chrono::nanoseconds(0);
};
} // namespace

class PassTimer : public PassInstrumentation {
public:
explicit PassTimer(bool print_module) : print_module_(print_module) {}
~PassTimer() = default;

void RunBeforePipeline(ir::Operation* op) override {
pipeline_timers_[op] = Timer();
pipeline_timers_[op].Start();
}

void RunAfterPipeline(Operation* op) override {
pipeline_timers_[op].Stop();
PrintTime(op, std::cout);
}

void RunBeforePass(Pass* pass, Operation* op) override {
if (!pass_timers_.count(op)) {
pass_timers_[op] = {};
}
pass_timers_[op][pass->name()] = Timer();
pass_timers_[op][pass->name()].Start();
}

void RunAfterPass(Pass* pass, Operation* op) override {
pass_timers_[op][pass->name()].Stop();
}

private:
void PrintTime(Operation* op, std::ostream& os) {
if (print_module_ && op->name() != "builtin.module") return;

std::string header = "PassTiming on " + op->name();
detail::PrintHeader(header, os);

os << " Total Execution Time: " << std::fixed << std::setprecision(3)
<< pipeline_timers_[op].GetTimePerSecond() << " seconds\n\n";
os << " ----Walk Time---- ----Name----\n";

auto& map = pass_timers_[op];
std::vector<std::pair<std::string, Timer>> pairs(map.begin(), map.end());
std::sort(pairs.begin(),
pairs.end(),
[](const std::pair<std::string, Timer>& lhs,
const std::pair<std::string, Timer>& rhs) {
return lhs.second.GetTimePerSecond() >
rhs.second.GetTimePerSecond();
});

for (auto& v : pairs) {
os << " " << std::fixed << std::setw(8) << std::setprecision(3)
<< v.second.GetTimePerSecond() << " (" << std::setw(5)
<< std::setprecision(1)
<< 100 * v.second.GetTimePerSecond() /
pipeline_timers_[op].GetTimePerSecond()
<< "%)"
<< " " << v.first << "\n";
}
}

private:
bool print_module_;

std::unordered_map<Operation*, Timer> pipeline_timers_;

std::unordered_map<Operation*,
std::unordered_map<std::string /*pass name*/, Timer>>
pass_timers_;
};

void PassManager::EnablePassTiming(bool print_module) {
AddInstrumentation(std::make_unique<PassTimer>(print_module));
}

} // namespace ir
28 changes: 28 additions & 0 deletions paddle/ir/pass/utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/pass/utils.h"

namespace ir {
namespace detail {

void PrintHeader(const std::string &header, std::ostream &os) {
unsigned padding = (80 - header.size()) / 2;
os << "===" << std::string(73, '-') << "===\n";
os << std::string(padding, ' ') << header << "\n";
os << "===" << std::string(73, '-') << "===\n";
}

} // namespace detail
} // namespace ir
Loading