Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] add compile error handler #57198

Merged
merged 3 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/cinn/auto_schedule/measure/simple_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ BuildResult SimpleBuilder::Build(const MeasureInput& input) {

BuildResult build_result;
build_result.compiled_scope = graph_compiler_->GetScope().get();
build_result.runtime_program = std::move(compiled_result.runtime_program);
build_result.runtime_program = std::move(compiled_result.RuntimeProgram());
return build_result;
}

Expand Down
49 changes: 32 additions & 17 deletions paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#ifdef CINN_WITH_CUDA
Expand All @@ -39,6 +40,7 @@ PD_DECLARE_string(cinn_dump_group_instruction);
namespace cinn {
namespace backends {
using ir::Module;
using CompilationStatus = hlir::framework::CompilationStatus;

static constexpr int DebugLogMaxLen = 30000;

Expand Down Expand Up @@ -88,9 +90,13 @@ void CompilationInfoDumper::DumpLoweredFunc() {
if (FLAGS_cinn_dump_group_lowered_func.empty()) {
return;
}
for (int idx = 0; idx < info_.lowered_funcs.size(); ++idx) {
for (int idx = 0; idx < info_.Size(); ++idx) {
std::stringstream content;
content << info_.lowered_funcs[idx].front();
if (info_.Status(idx) > CompilationStatus::LOWERING_FAIL) {
content << info_.LoweredFuncs(idx).front();
} else {
content << "[No lowered func generated]\n\n" << info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_lowered_func,
idx,
"lowered_function.txt",
Expand All @@ -102,35 +108,44 @@ void CompilationInfoDumper::DumpSourceCode() {
if (FLAGS_cinn_dump_group_source_code.empty()) {
return;
}
for (int idx = 0; idx < info_.source_codes.size(); ++idx) {
Dump(FLAGS_cinn_dump_group_source_code,
idx,
"source_code.cu",
info_.source_codes[idx]);
for (int idx = 0; idx < info_.Size(); ++idx) {
std::string dump_str;
if (info_.Status(idx) > CompilationStatus::CODEGEN_JIT_FAIL) {
dump_str = info_.SourceCode(idx);
} else {
dump_str = "[No source code generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_source_code, idx, "source_code.cu", dump_str);
}
}

void CompilationInfoDumper::DumpPtxCode() {
if (FLAGS_cinn_dump_group_ptx.empty()) {
return;
}
for (int idx = 0; idx < info_.source_ptxs.size(); ++idx) {
Dump(FLAGS_cinn_dump_group_ptx,
idx,
"source_ptx.ptx",
info_.source_ptxs[idx]);
for (int idx = 0; idx < info_.Size(); ++idx) {
std::string dump_str;
if (info_.Status(idx) > CompilationStatus::CODEGEN_JIT_FAIL) {
dump_str = info_.SourcePtx(idx);
} else {
dump_str = "[No source ptxs generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_ptx, idx, "source_ptx.ptx", dump_str);
}
}

void CompilationInfoDumper::DumpInstruction() {
if (FLAGS_cinn_dump_group_instruction.empty()) {
return;
}
for (int idx = 0; idx < info_.instructions.size(); ++idx) {
Dump(FLAGS_cinn_dump_group_instruction,
idx,
"instruction.txt",
info_.instructions[idx]->DumpInstruction());
for (int idx = 0; idx < info_.RuntimeInstructions().size(); ++idx) {
std::string dump_str;
if (info_.RuntimeInstruction(idx).get() != nullptr) {
dump_str = info_.RuntimeInstruction(idx)->DumpInstruction();
} else {
dump_str = "[No instruction generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_instruction, idx, "instruction.txt", dump_str);
}
}

Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ gather_srcs(
op_lowering_util.cc
op_lowering_impl.cc
accuracy_checker.cc
visualize_helper.cc)
visualize_helper.cc
compile_error.cc)

# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and could
# not found under CINN_ONLY mode
Expand Down
41 changes: 41 additions & 0 deletions paddle/cinn/hlir/framework/compile_error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/utils/enum_string.h"

namespace cinn {
namespace hlir {
namespace framework {

std::string CompileErrorHandler::GeneralErrorMessage() const {
std::ostringstream os;
os << "[CompileError] An error occurred during compilation with the error "
"code: "
<< utils::Enum2String(status_) << std::endl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我看了下Enum2String的实现,还挺复杂的,如果只是为了把enum转成相应的string,是不是可以简化一下实现,还是说有其他的考虑

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该实现可以支持任意枚举类转为String,未来若有其他需求可以复用,无需每个枚举类自己定义ToString().

os << "(at " << file_ << " : " << line_ << ")" << std::endl;
os << indent_str_ << "[Error info] " << this->err_msg_ << std::endl;
return os.str();
}

std::string CompileErrorHandler::DetailedErrorMessage() const {
std::ostringstream os;
os << GeneralErrorMessage();
os << indent_str_ << "[Detail info] " << detail_info_ << std::endl;
return os.str();
}

} // namespace framework
} // namespace hlir
} // namespace cinn
68 changes: 68 additions & 0 deletions paddle/cinn/hlir/framework/compile_error.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/utils/error.h"

namespace cinn {
namespace hlir {
namespace framework {

/**
* This handler is used to deal with the errors during the compilation process
*/
class CompileErrorHandler : public utils::ErrorHandler {
public:
/**
* \brief constructor
* \param err_msg the error message
*/
explicit CompileErrorHandler(const CompilationStatus& status,
const std::string& err_msg,
const std::string& detail_info,
const char* file,
int line)
: status_(status),
err_msg_(err_msg),
detail_info_(detail_info),
file_(file),
line_(line) {}

/**
* \brief Returns a short error message corresponding to the kGeneral error
* level.
*/
std::string GeneralErrorMessage() const;

/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
std::string DetailedErrorMessage() const;

CompilationStatus Status() const { return status_; }

private:
CompilationStatus status_;
std::string err_msg_;
std::string detail_info_;
const char* file_;
int line_;
};

} // namespace framework
} // namespace hlir
} // namespace cinn
13 changes: 7 additions & 6 deletions paddle/cinn/hlir/framework/graph_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/utils/enum_string.h"
#include "paddle/cinn/utils/profiler.h"

namespace cinn {
Expand All @@ -44,7 +45,7 @@ std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
compilation_context_.with_instantiate_variables = true;

auto&& result = Build(&compilation_context_);
return std::move(result.runtime_program);
return result.RuntimeProgram();
}

CompilationResult GraphCompiler::Build(CompilationContext* context) {
Expand All @@ -64,22 +65,22 @@ CompilationResult GraphCompiler::Build(CompilationContext* context) {
parallel_compiler_ = std::make_shared<ParallelCompiler>(context);
CompilationResult result = (*parallel_compiler_.get())();

if (context->stage != CompilationStage::DEFAULT) {
if (context->stage != CompilationStage::DEFAULT || !result.IsSuccess()) {
return result;
}

if (context->remove_unused_variables) {
RemoveInvalidVariables(context, result.instructions);
RemoveInvalidVariables(context, result.RuntimeInstructions());
}

if (context->with_buffer_handle_instruction_inserted) {
VLOG(3) << "option.with_buffer_handle_instruction_inserted enable";
InsertBufferHandlers(context, &result.instructions);
InsertBufferHandlers(context, &result.instructions_);
}
VLOG(2) << "Compile With Parallel Compiler Done!";

result.runtime_program =
std::make_unique<Program>(context->scope, std::move(result.instructions));
result.SetRuntimeProgram(std::make_unique<Program>(
context->scope, std::move(result.instructions_)));
return result;
}

Expand Down
10 changes: 5 additions & 5 deletions paddle/cinn/hlir/framework/graph_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
GraphCompiler gc_disable(context_disable);
// disable with_buffer_handle_instruction_inserted: only 1 instruction
auto runtime_program_disable =
gc_disable.Build(&context_disable).runtime_program;
gc_disable.Build(&context_disable).RuntimeProgram();
ASSERT_EQ(runtime_program_disable->size(), 1);
const auto& computation_instr_disable =
runtime_program_disable->GetRunInstructions().front();
Expand All @@ -87,7 +87,7 @@ TEST(GraphCompilerTest, TestInsertBufferHandlers) {
context_enable.with_buffer_handle_instruction_inserted = true;
GraphCompiler gc_enable(context_enable);
auto runtime_program_enable =
gc_enable.Build(&context_enable).runtime_program;
gc_enable.Build(&context_enable).RuntimeProgram();
const auto& instructions = runtime_program_enable->GetRunInstructions();
ASSERT_EQ(instructions.size(), 3);

Expand Down Expand Up @@ -254,7 +254,7 @@ TEST(GraphCompilerTest, TestLowering) {
GraphCompiler gc(context);
CompilationResult result = gc.Lowering();

ASSERT_EQ(result.status, CompilationStatus::SUCCESS);
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}

TEST(GraphCompilerTest, TestCodegenAndJit) {
Expand All @@ -274,7 +274,7 @@ TEST(GraphCompilerTest, TestCodegenAndJit) {
GraphCompiler gc(context);
CompilationResult result = gc.CodegenAndJit();

ASSERT_EQ(result.status, CompilationStatus::SUCCESS);
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}

TEST(GraphCompilerTest, TestBuildInstruction) {
Expand All @@ -294,7 +294,7 @@ TEST(GraphCompilerTest, TestBuildInstruction) {
GraphCompiler gc(context);
CompilationResult result = gc.BuildInstruction();

ASSERT_EQ(result.status, CompilationStatus::SUCCESS);
ASSERT_EQ(result.Status(), CompilationStatus::SUCCESS);
}

#endif
Expand Down
Loading