Skip to content

[CINN] Compile time status count #72206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/cinn/backends/compiler.h"

#include <sys/stat.h>
#include <chrono>
#include <fstream>
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_device_util.h"
Expand Down Expand Up @@ -48,6 +49,7 @@ PD_DECLARE_string(cinn_dump_group_source_code);
PD_DECLARE_string(cinn_dump_group_ptx);
PD_DECLARE_string(cinn_dump_group_instruction);
PD_DECLARE_string(cinn_debug_custom_code_path);
COMMON_DECLARE_bool(cinn_debug);

namespace {

Expand Down Expand Up @@ -258,7 +260,15 @@ void Compiler::AppendBroadcastSwitchModule(const ir::Module& module) {
}

void Compiler::EndCompile() {
auto start = std::chrono::high_resolution_clock::now();
RegisterDeviceModuleSymbol();
auto end = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
if (FLAGS_cinn_debug) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在别加呀 😂 现在还没有这个 flag,我是说下周一 #72181 合入后再加

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看错了...以为是下周一合入前...怪不得编译不过呢 ┭┮﹏┭┮

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR 已合入,可以 merge 下最新 develop,另外记得相关 flag 使用前需要 declare 下

LOG(INFO) << "Time of nvrtc compile: ***** [ " << duration.count()
<< " ] ***** microseconds.";
}
engine_->AddSelfModule();
}

Expand Down Expand Up @@ -446,8 +456,16 @@ void Compiler::RegisterSyclModuleSymbol() {
void Compiler::CompileCudaModule(const Module& module,
const std::string& code) {
#ifdef CINN_WITH_CUDA
auto start = std::chrono::high_resolution_clock::now();
auto _host_module_device_module_ =
SplitDeviceAndHostModule(module); // NOLINT
auto end = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
if (FLAGS_cinn_debug) {
LOG(INFO) << "Time of SplitDeviceAndHostModule: ***** [ "
<< duration.count() << " ] ***** microseconds.";
}
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
VLOG(3) << "[CUDA] host module:\n" << host_module;
Expand All @@ -460,7 +478,15 @@ void Compiler::CompileCudaModule(const Module& module,
source_code = GetFileContent(file_path);
} else if (code.empty()) {
CodeGenCudaDev codegen(target_);
start = std::chrono::high_resolution_clock::now();
source_code = codegen.Compile(device_module);
end = std::chrono::high_resolution_clock::now();
duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
if (FLAGS_cinn_debug) {
LOG(INFO) << "Time of backend compiler compile device module: ***** [ "
<< duration.count() << " ] ***** microseconds.";
}
} else {
source_code = code;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"

#include <chrono>
#include "paddle/cinn/adt/generate_map_expr.h"
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
Expand All @@ -28,6 +28,7 @@

PD_DECLARE_bool(cinn_enable_map_expr);
PD_DECLARE_bool(enable_cinn_compile_cache);
COMMON_DECLARE_bool(cinn_debug);

namespace cinn::dialect::ir::details {

Expand All @@ -44,8 +45,18 @@ std::unordered_map<std::string, ::pir::Attribute> GetJitKernelAttr(
return CompilationCache::Instance().GetKernelInfo(fusion_info);
};
// Compiles `group` from scratch: constructs a PirCompiler for the default
// device target, calls Build on the single-element list {group}, and returns
// element 0 of what Build yields. When FLAGS_cinn_debug is set, the elapsed
// time of construction + Build is logged at INFO level in microseconds.
const auto& CreateFromNewCompile = [&]() {
  // steady_clock is monotonic, so the measured interval cannot be distorted
  // (or go negative) if the system wall clock is adjusted mid-compile.
  using Clock = std::chrono::steady_clock;
  const auto start = Clock::now();
  PirCompiler pir_compiler(cinn::common::DefaultDeviceTarget());
  // Keep the result in a local so the timing/logging below actually runs
  // before we return (an early `return` here would make it unreachable).
  auto res = pir_compiler.Build({group})[0];
  const auto duration = std::chrono::duration_cast<std::chrono::microseconds>(
      Clock::now() - start);
  if (FLAGS_cinn_debug) {
    LOG(INFO) << "Time of lowering and compiling group: ***** [ "
              << duration.count() << " ] ***** microseconds.";
  }
  return res;
};

if (FLAGS_enable_cinn_compile_cache) {
Expand Down
24 changes: 22 additions & 2 deletions paddle/cinn/hlir/framework/pir/compilation_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#pragma once

#include "paddle/cinn/hlir/framework/pir/compilation_task.h"

#include <chrono>
#include "paddle/cinn/backends/codegen_device_util.h"
#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/common/target.h"
Expand All @@ -24,6 +24,9 @@
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/ir/utils/stmt_converter.h"
#include "paddle/common/enforce.h"

COMMON_DECLARE_bool(cinn_debug);

namespace cinn {
namespace hlir {
namespace framework {
Expand Down Expand Up @@ -190,8 +193,25 @@ void UnifyBroadcastGroupFuncArgs(
}

// Runs this compilation task: first Lowering(), then CodegenAndJit(), and
// returns whatever CodegenAndJit() produced.
//
// When FLAGS_cinn_debug is set, the elapsed time of each of the two phases is
// logged separately at INFO level, in microseconds.
std::shared_ptr<pir::CompilationResult> CompilationTask::operator()() {
  // steady_clock is monotonic, so the measured intervals cannot be distorted
  // (or go negative) if the system wall clock is adjusted mid-compile.
  using Clock = std::chrono::steady_clock;
  const auto elapsed_us = [](Clock::time_point since) {
    return std::chrono::duration_cast<std::chrono::microseconds>(Clock::now() -
                                                                 since)
        .count();
  };

  // Phase 1: lowering.
  const auto lowering_start = Clock::now();
  Lowering();
  const auto lowering_us = elapsed_us(lowering_start);
  if (FLAGS_cinn_debug) {
    LOG(INFO) << "Time of lowering: ***** [ " << lowering_us
              << " ] ***** microseconds.";
  }

  // Phase 2: code generation and JIT. The result is held in a local so the
  // timing can be reported before returning it.
  const auto codegen_start = Clock::now();
  auto result = CodegenAndJit();
  const auto codegen_us = elapsed_us(codegen_start);
  if (FLAGS_cinn_debug) {
    LOG(INFO) << "Time of codegenandjit: ***** [ " << codegen_us
              << " ] ***** microseconds.";
  }
  return result;
}

void CompilationTask::Lowering() {
Expand Down
19 changes: 16 additions & 3 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
PD_DECLARE_bool(cinn_use_cuda_vectorize);
PD_DECLARE_bool(cinn_check_tensor_buffer_map);
PD_DECLARE_bool(cinn_longlong2int);
COMMON_DECLARE_bool(cinn_debug);
const int default_priority = 100;

namespace cinn {
Expand Down Expand Up @@ -118,7 +119,7 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(
}

// =========== OpFusion ============

auto start = std::chrono::high_resolution_clock::now();
// VLOG(4) << "Bucket Lower output values is : " << group->output_values();
func_bodies = OperationFusion(ops,
func_bodies,
Expand All @@ -133,7 +134,6 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(
for (auto value : group->GetInputOpValues()) {
fusion_group_args.insert(ValueName(value));
}

for (auto value : group->GetGroupOutputValues()) {
fusion_group_args.insert(ValueName(value));
}
Expand All @@ -155,9 +155,16 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(
optim::CheckTensorBufferMap(func_bodies, "BucketLower OpFusion");
VLOG(3) << "OpFusion tensor-buffer map check succeed";
}
auto end = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
if (FLAGS_cinn_debug) {
LOG(INFO) << "Time of OpFusion: ***** [ " << duration.count()
<< " ] ***** microseconds.";
}

// =========== CodeGen And Optimizer ================

start = std::chrono::high_resolution_clock::now();
// 2.Do group schedule.
ir::ModuleExpr mod_expr(func_bodies);
ir::IRSchedule ir_sch(
Expand Down Expand Up @@ -205,6 +212,12 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(
}
VLOG(3) << "Schedule tensor-buffer map check succeed";
}
end = std::chrono::high_resolution_clock::now();
duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
if (FLAGS_cinn_debug) {
LOG(INFO) << "Time of group schedule: ***** [ " << duration.count()
<< " ] ***** microseconds.";
}

// 3.Do post-processing,
// including preparing function args and temporary variables,
Expand Down
14 changes: 11 additions & 3 deletions paddle/cinn/hlir/framework/pir_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
// limitations under the License.

#include "paddle/cinn/hlir/framework/pir_compiler.h"
#include "paddle/cinn/ir/group_schedule/config/schedule_config_manager.h"

#include <chrono>
#include "paddle/cinn/common/shape_constraint.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"
#include "paddle/cinn/hlir/framework/pir/broadcast_with_cf.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/ir/group_schedule/config/schedule_config_manager.h"
#include "paddle/cinn/runtime/arch_device.h"
#include "paddle/cinn/utils/multi_threading.h"
#include "paddle/common/enforce.h"
Expand All @@ -27,6 +27,7 @@

PD_DECLARE_bool(enable_cinn_compile_cache);
PD_DECLARE_int64(cinn_compile_thread_num);
COMMON_DECLARE_bool(cinn_debug);

namespace cinn::hlir::framework {
class CompilationContextMapper {
Expand Down Expand Up @@ -107,7 +108,6 @@ std::shared_ptr<pir::CompilationResult> PirCompiler::Compile(
GroupCompilationContext* ctx) {
std::shared_ptr<pir::CompilationResult> compile_result;
CompilationTask task(ctx);

const auto& optional_broadcast_optimize_groups =
pir::GetBroadcastGroupListForOptimize(ctx->GetGroup());

Expand Down Expand Up @@ -146,8 +146,16 @@ std::shared_ptr<pir::CompilationResult> PirCompiler::Compile(
compile_result = task();
}

auto start = std::chrono::high_resolution_clock::now();
// Triggering llvm compilation in thread
compile_result->GetKernelInfo();
auto end = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
if (FLAGS_cinn_debug) {
LOG(INFO) << "Time of llvm compile: ***** [ " << duration.count()
<< " ] ***** microseconds.";
}
return compile_result;
}

Expand Down
7 changes: 4 additions & 3 deletions paddle/cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

#include "paddle/cinn/ir/ir.h"

#include <chrono>
#include <map>
#include <optional>
#include <regex>
Expand Down Expand Up @@ -723,8 +723,9 @@ Expr Store::Make(Expr tensor, Expr value, const std::vector<Expr> &indices) {
utils::GetCompatibleStoreLoadIndices(tensor.as_tensor_ref(), indices);

if (tensor->type() != Void()) {
node->set_type(
tensor->type().ElementOf().with_lanes(node->index().type().lanes()));
// node->set_type(
// tensor->type().ElementOf().with_lanes(node->index().type().lanes()));
Comment on lines +726 to +727
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下一个PR把注释删除吧

node->set_type(tensor->type());
}
return Expr(node);
}
Expand Down
Loading