Skip to content

Commit

Permalink
Cinn error refactor (PaddlePaddle#55544)
Browse files Browse the repository at this point in the history
* Refactor the error message system

* fix header

* fix compile
  • Loading branch information
ZzSean authored and wyf committed Aug 30, 2023
1 parent 22cd8dc commit 6fb9a90
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 213 deletions.
4 changes: 2 additions & 2 deletions paddle/cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ void TestSplitThrow() {
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(
mod_expr, -1, false, ir::ScheduleErrorMessageLevel::kGeneral);
mod_expr, -1, false, utils::ErrorMessageLevel::kGeneral);
auto fused = ir_sch.Fuse("B", {0, 1});
// statement that cause the exception
auto splited = ir_sch.Split(fused, {-1, -1});
Expand All @@ -196,7 +196,7 @@ void TestSplitThrow() {
auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
}
TEST(IrSchedule, split_throw) {
ASSERT_THROW(TestSplitThrow(), ir::enforce::EnforceNotMet);
ASSERT_THROW(TestSplitThrow(), utils::enforce::EnforceNotMet);
}

TEST(IrSchedule, reorder1) {
Expand Down
26 changes: 12 additions & 14 deletions paddle/cinn/ir/schedule/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/string.h"

DECLARE_int32(cinn_schedule_error_message_level);
DECLARE_int32(cinn_error_message_level);

namespace cinn {
namespace ir {
Expand All @@ -54,12 +54,11 @@ class ScheduleImpl {
ScheduleImpl() = default;
explicit ScheduleImpl(const ModuleExpr& module_expr,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral)
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral)
: module_expr_(module_expr), debug_flag_(debug_flag) {
err_msg_level_ = static_cast<ScheduleErrorMessageLevel>(
FLAGS_cinn_schedule_error_message_level ||
static_cast<int>(err_msg_level));
err_msg_level_ = static_cast<utils::ErrorMessageLevel>(
FLAGS_cinn_error_message_level || static_cast<int>(err_msg_level));
}
explicit ScheduleImpl(ModuleExpr&& module_expr)
: module_expr_(std::move(module_expr)) {}
Expand Down Expand Up @@ -138,8 +137,7 @@ class ScheduleImpl {

ModuleExpr module_expr_;
bool debug_flag_{false};
ScheduleErrorMessageLevel err_msg_level_ =
ScheduleErrorMessageLevel::kGeneral;
utils::ErrorMessageLevel err_msg_level_ = utils::ErrorMessageLevel::kGeneral;
};

/** \brief A macro that guards the beginning of each implementation of schedule
Expand All @@ -152,10 +150,10 @@ class ScheduleImpl {
* @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message
* printing
*/
#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \
} \
catch (const IRScheduleErrorHandler& err_hanlder) { \
CINN_THROW(err_hanlder.FormatErrorMessage(primitive, err_msg_level)); \
#define CINN_IR_SCHEDULE_END(err_msg_level) \
} \
catch (const utils::ErrorHandler& err_hanlder) { \
CINN_THROW(err_hanlder.FormatErrorMessage(err_msg_level)); \
}

std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
Expand All @@ -177,7 +175,7 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
std::vector<int> processed_factors;
CINN_IR_SCHEDULE_BEGIN();
processed_factors = ValidateFactors(factors, tot_extent, this->module_expr_);
CINN_IR_SCHEDULE_END("split", this->err_msg_level_);
CINN_IR_SCHEDULE_END(this->err_msg_level_);
int prod_size = std::accumulate(processed_factors.begin(),
processed_factors.end(),
1,
Expand Down Expand Up @@ -2316,7 +2314,7 @@ IRSchedule::IRSchedule() {}
IRSchedule::IRSchedule(const ModuleExpr& module_expr,
utils::LinearRandomEngine::StateType rand_seed,
bool debug_flag,
ScheduleErrorMessageLevel err_msg_level) {
utils::ErrorMessageLevel err_msg_level) {
impl_ =
std::make_unique<ScheduleImpl>(module_expr, debug_flag, err_msg_level);
this->InitSeed(rand_seed);
Expand Down
19 changes: 3 additions & 16 deletions paddle/cinn/ir/schedule/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,12 @@
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/random_engine.h"

namespace cinn {
namespace ir {

/**
* \brief Indicates the level of printing error message in the current Schedule
*/
enum class ScheduleErrorMessageLevel : int32_t {
/** \brief Print an error message in short mode.
* Short mode shows which and where the schedule error happens*/
kGeneral = 0,
/** \brief Print an error message in detailed mode.
* Detailed mode shows which and where the schedule error happens, and the
* schedule input parameters.
*/
kDetailed = 1,
};

/**
* A struct representing a module that contains Expr. This struct is only used
* in Schedule process.
Expand Down Expand Up @@ -85,8 +72,8 @@ class IRSchedule {
explicit IRSchedule(const ModuleExpr& modexpr,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral);
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral);
IRSchedule(ir::ModuleExpr&& mod_expr,
ScheduleDesc&& trace,
utils::LinearRandomEngine::StateType rand_seed = -1);
Expand Down
45 changes: 5 additions & 40 deletions paddle/cinn/ir/schedule/ir_schedule_error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ namespace cinn {
namespace ir {

std::string IRScheduleErrorHandler::GeneralErrorMessage() const {
return this->err_msg_;
std::ostringstream os;
os << "[IRScheduleError] An error occurred in the scheduel primitive < "
<< this->primitive_ << " >. " << std::endl;
os << this->err_msg_;
return os.str();
}

std::string IRScheduleErrorHandler::DetailedErrorMessage() const {
Expand All @@ -31,44 +35,5 @@ std::string IRScheduleErrorHandler::DetailedErrorMessage() const {
return os.str();
}

std::string IRScheduleErrorHandler::FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const {
std::ostringstream os;
std::string err_msg = err_msg_level == ScheduleErrorMessageLevel::kDetailed
? DetailedErrorMessage()
: GeneralErrorMessage();

os << "[IRScheduleError] An error occurred in the scheduel primitive <"
<< primitive << ">. " << std::endl;
os << "[Error info] " << err_msg;
return os.str();
}

std::string NegativeFactorErrorMessage(const int64_t& factor,
const size_t& idx) {
std::ostringstream os;
os << "The params in factors of Split should be positive. However, the "
"factor at position "
<< idx << " is " << factor << std::endl;
return os.str();
}

std::string InferFactorErrorMessage() {
std::ostringstream os;
os << "The params in factors of Split should not be less than -1 or have "
"more than one -1!"
<< std::endl;
return os.str();
}

std::string FactorProductErrorMessage() {
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
return os.str();
}

} // namespace ir
} // namespace cinn
131 changes: 6 additions & 125 deletions paddle/cinn/ir/schedule/ir_schedule_error.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,130 +14,25 @@

#pragma once

#ifdef __GNUC__
#include <cxxabi.h> // for __cxa_demangle
#endif // __GNUC__

#if !defined(_WIN32)
#include <dlfcn.h> // dladdr
#include <unistd.h> // sleep, usleep
#else // _WIN32
#ifndef NOMINMAX
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#endif
#include <windows.h> // GetModuleFileName, Sleep
#endif

#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
#include <execinfo.h>
#endif

#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/cinn/ir/schedule/ir_schedule.h"

namespace cinn {
namespace ir {

namespace enforce {

#ifdef __GNUC__
inline std::string demangle(std::string name) {
int status = -4; // some arbitrary value to eliminate the compiler warning
std::unique_ptr<char, void (*)(void*)> res{
abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free};
return (status == 0) ? res.get() : name;
}
#else
inline std::string demangle(std::string name) { return name; }
#endif

static std::string GetErrorSumaryString(const std::string& what,
const char* file,
int line) {
std::ostringstream sout;
sout << "\n----------------------\nError Message "
"Summary:\n----------------------\n";
sout << what << "(at " << file << " : " << line << ")" << std::endl;
return sout.str();
}

static std::string GetCurrentTraceBackString() {
std::ostringstream sout;
sout << "\n\n--------------------------------------\n";
sout << "C++ Traceback (most recent call last):";
sout << "\n--------------------------------------\n";
#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
static constexpr int TRACE_STACK_LIMIT = 100;

void* call_stack[TRACE_STACK_LIMIT];
auto size = backtrace(call_stack, TRACE_STACK_LIMIT);
auto symbols = backtrace_symbols(call_stack, size);
Dl_info info;
int idx = 0;
int end_idx = 0;
for (int i = size - 1; i >= end_idx; --i) {
if (dladdr(call_stack[i], &info) && info.dli_sname) {
auto demangled = demangle(info.dli_sname);
std::string path(info.dli_fname);
// C++ traceback info are from core.so
if (path.substr(path.length() - 3).compare(".so") == 0) {
sout << idx++ << " " << demangled << "\n";
}
}
}
free(symbols);
#else
sout << "Not support stack backtrace yet.\n";
#endif
return sout.str();
}

static std::string GetTraceBackString(const std::string& what,
const char* file,
int line) {
return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line);
}

struct EnforceNotMet : public std::exception {
public:
EnforceNotMet(const std::string& str, const char* file, int line)
: err_str_(GetTraceBackString(str, file, line)) {}

const char* what() const noexcept override { return err_str_.c_str(); }

private:
std::string err_str_;
};

#define CINN_THROW(...) \
do { \
try { \
throw enforce::EnforceNotMet(__VA_ARGS__, __FILE__, __LINE__); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
} \
} while (0)
} // namespace enforce

/**
* This handler is dealing with the errors happen in in the current
* Scheduling.
*/
class IRScheduleErrorHandler {
class IRScheduleErrorHandler : public utils::ErrorHandler {
public:
/**
* \brief constructor
* \param err_msg the error message
*/
explicit IRScheduleErrorHandler(const std::string& err_msg,
explicit IRScheduleErrorHandler(const std::string& primitive,
const std::string& err_msg,
const ModuleExpr& module_expr)
: err_msg_(err_msg), module_expr_(module_expr) {}
: primitive_(primitive), err_msg_(err_msg), module_expr_(module_expr) {}

/**
* \brief Returns a short error message corresponding to the kGeneral error
Expand All @@ -151,25 +46,11 @@ class IRScheduleErrorHandler {
*/
std::string DetailedErrorMessage() const;

/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
std::string FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const;

private:
ModuleExpr module_expr_;
std::string primitive_;
std::string err_msg_;
ModuleExpr module_expr_;
};

std::string NegativeFactorErrorMessage(const int64_t& factor,
const size_t& idx);

std::string InferFactorErrorMessage();

std::string FactorProductErrorMessage();

} // namespace ir
} // namespace cinn
27 changes: 22 additions & 5 deletions paddle/cinn/ir/schedule/ir_schedule_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ void ReplaceExpr(Expr* source,
std::vector<int> ValidateFactors(const std::vector<int>& factors,
int total_extent,
const ModuleExpr& module_expr) {
const std::string primitive = "split";
CHECK(!factors.empty())
<< "The factors param of Split should not be empty! Please check.";
bool has_minus_one = false;
Expand All @@ -230,11 +231,19 @@ std::vector<int> ValidateFactors(const std::vector<int>& factors,
for (auto& i : factors) {
idx++;
if (i == 0 || i < -1) {
throw IRScheduleErrorHandler(NegativeFactorErrorMessage(i, idx),
module_expr);
std::ostringstream os;
os << "The params in factors of Split should be positive. However, the "
"factor at position "
<< idx << " is " << i << std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
} else if (i == -1) {
if (has_minus_one) {
throw IRScheduleErrorHandler(InferFactorErrorMessage(), module_expr);
std::ostringstream os;
os << "The params in factors of Split should not be less than -1 or "
"have "
"more than one -1!"
<< std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
}
has_minus_one = true;
} else {
Expand All @@ -244,12 +253,20 @@ std::vector<int> ValidateFactors(const std::vector<int>& factors,
std::vector<int> validated_factors = factors;
if (!has_minus_one) {
if (product < total_extent) {
throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr);
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
}
return validated_factors;
} else {
if (product > total_extent) {
throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr);
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
}
int minus_one_candidate = static_cast<int>(
ceil(static_cast<double>(total_extent) / static_cast<double>(product)));
Expand Down
Loading

0 comments on commit 6fb9a90

Please sign in to comment.