Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cinn error refactor #55544

Merged
merged 4 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ void TestSplitThrow() {
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(
mod_expr, -1, false, ir::ScheduleErrorMessageLevel::kGeneral);
mod_expr, -1, false, utils::ErrorMessageLevel::kGeneral);
auto fused = ir_sch.Fuse("B", {0, 1});
// statement that cause the exception
auto splited = ir_sch.Split(fused, {-1, -1});
Expand All @@ -196,7 +196,7 @@ void TestSplitThrow() {
auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
}
TEST(IrSchedule, split_throw) {
ASSERT_THROW(TestSplitThrow(), ir::enforce::EnforceNotMet);
ASSERT_THROW(TestSplitThrow(), utils::enforce::EnforceNotMet);
}

TEST(IrSchedule, reorder1) {
Expand Down
26 changes: 12 additions & 14 deletions paddle/cinn/ir/schedule/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/string.h"

DECLARE_int32(cinn_schedule_error_message_level);
DECLARE_int32(cinn_error_message_level);

namespace cinn {
namespace ir {
Expand All @@ -54,12 +54,11 @@ class ScheduleImpl {
ScheduleImpl() = default;
explicit ScheduleImpl(const ModuleExpr& module_expr,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral)
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral)
: module_expr_(module_expr), debug_flag_(debug_flag) {
err_msg_level_ = static_cast<ScheduleErrorMessageLevel>(
FLAGS_cinn_schedule_error_message_level ||
static_cast<int>(err_msg_level));
err_msg_level_ = static_cast<utils::ErrorMessageLevel>(
FLAGS_cinn_error_message_level || static_cast<int>(err_msg_level));
}
explicit ScheduleImpl(ModuleExpr&& module_expr)
: module_expr_(std::move(module_expr)) {}
Expand Down Expand Up @@ -138,8 +137,7 @@ class ScheduleImpl {

ModuleExpr module_expr_;
bool debug_flag_{false};
ScheduleErrorMessageLevel err_msg_level_ =
ScheduleErrorMessageLevel::kGeneral;
utils::ErrorMessageLevel err_msg_level_ = utils::ErrorMessageLevel::kGeneral;
};

/** \brief A macro that guards the beginning of each implementation of schedule
Expand All @@ -152,10 +150,10 @@ class ScheduleImpl {
* @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message
* printing
*/
#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \
} \
catch (const IRScheduleErrorHandler& err_hanlder) { \
CINN_THROW(err_hanlder.FormatErrorMessage(primitive, err_msg_level)); \
#define CINN_IR_SCHEDULE_END(err_msg_level) \
} \
catch (const utils::ErrorHandler& err_hanlder) { \
CINN_THROW(err_hanlder.FormatErrorMessage(err_msg_level)); \
}

std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
Expand All @@ -177,7 +175,7 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
std::vector<int> processed_factors;
CINN_IR_SCHEDULE_BEGIN();
processed_factors = ValidateFactors(factors, tot_extent, this->module_expr_);
CINN_IR_SCHEDULE_END("split", this->err_msg_level_);
CINN_IR_SCHEDULE_END(this->err_msg_level_);
int prod_size = std::accumulate(processed_factors.begin(),
processed_factors.end(),
1,
Expand Down Expand Up @@ -2316,7 +2314,7 @@ IRSchedule::IRSchedule() {}
IRSchedule::IRSchedule(const ModuleExpr& module_expr,
utils::LinearRandomEngine::StateType rand_seed,
bool debug_flag,
ScheduleErrorMessageLevel err_msg_level) {
utils::ErrorMessageLevel err_msg_level) {
impl_ =
std::make_unique<ScheduleImpl>(module_expr, debug_flag, err_msg_level);
this->InitSeed(rand_seed);
Expand Down
19 changes: 3 additions & 16 deletions paddle/cinn/ir/schedule/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,12 @@
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/random_engine.h"

namespace cinn {
namespace ir {

/**
* \brief Indicates the level of printing error message in the current Schedule
*/
enum class ScheduleErrorMessageLevel : int32_t {
/** \brief Print an error message in short mode.
* Short mode shows which and where the schedule error happens*/
kGeneral = 0,
/** \brief Print an error message in detailed mode.
* Detailed mode shows which and where the schedule error happens, and the
* schedule input parameters.
*/
kDetailed = 1,
};

/**
* A struct representing a module that contains Expr. This struct is only used
* in Schedule process.
Expand Down Expand Up @@ -85,8 +72,8 @@ class IRSchedule {
explicit IRSchedule(const ModuleExpr& modexpr,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level =
ScheduleErrorMessageLevel::kGeneral);
utils::ErrorMessageLevel err_msg_level =
utils::ErrorMessageLevel::kGeneral);
IRSchedule(ir::ModuleExpr&& mod_expr,
ScheduleDesc&& trace,
utils::LinearRandomEngine::StateType rand_seed = -1);
Expand Down
45 changes: 5 additions & 40 deletions paddle/cinn/ir/schedule/ir_schedule_error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ namespace cinn {
namespace ir {

std::string IRScheduleErrorHandler::GeneralErrorMessage() const {
return this->err_msg_;
std::ostringstream os;
os << "[IRScheduleError] An error occurred in the scheduel primitive < "
<< this->primitive_ << " >. " << std::endl;
os << this->err_msg_;
return os.str();
}

std::string IRScheduleErrorHandler::DetailedErrorMessage() const {
Expand All @@ -31,44 +35,5 @@ std::string IRScheduleErrorHandler::DetailedErrorMessage() const {
return os.str();
}

std::string IRScheduleErrorHandler::FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const {
std::ostringstream os;
std::string err_msg = err_msg_level == ScheduleErrorMessageLevel::kDetailed
? DetailedErrorMessage()
: GeneralErrorMessage();

os << "[IRScheduleError] An error occurred in the scheduel primitive <"
<< primitive << ">. " << std::endl;
os << "[Error info] " << err_msg;
return os.str();
}

std::string NegativeFactorErrorMessage(const int64_t& factor,
const size_t& idx) {
std::ostringstream os;
os << "The params in factors of Split should be positive. However, the "
"factor at position "
<< idx << " is " << factor << std::endl;
return os.str();
}

std::string InferFactorErrorMessage() {
std::ostringstream os;
os << "The params in factors of Split should not be less than -1 or have "
"more than one -1!"
<< std::endl;
return os.str();
}

std::string FactorProductErrorMessage() {
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
return os.str();
}

} // namespace ir
} // namespace cinn
131 changes: 6 additions & 125 deletions paddle/cinn/ir/schedule/ir_schedule_error.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,130 +14,25 @@

#pragma once

#ifdef __GNUC__
#include <cxxabi.h> // for __cxa_demangle
#endif // __GNUC__

#if !defined(_WIN32)
#include <dlfcn.h> // dladdr
#include <unistd.h> // sleep, usleep
#else // _WIN32
#ifndef NOMINMAX
#define NOMINMAX // msvc max/min macro conflict with std::min/max
#endif
#include <windows.h> // GetModuleFileName, Sleep
#endif

#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
#include <execinfo.h>
#endif

#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/cinn/ir/schedule/ir_schedule.h"

namespace cinn {
namespace ir {

namespace enforce {

#ifdef __GNUC__
inline std::string demangle(std::string name) {
int status = -4; // some arbitrary value to eliminate the compiler warning
std::unique_ptr<char, void (*)(void*)> res{
abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free};
return (status == 0) ? res.get() : name;
}
#else
inline std::string demangle(std::string name) { return name; }
#endif

static std::string GetErrorSumaryString(const std::string& what,
const char* file,
int line) {
std::ostringstream sout;
sout << "\n----------------------\nError Message "
"Summary:\n----------------------\n";
sout << what << "(at " << file << " : " << line << ")" << std::endl;
return sout.str();
}

static std::string GetCurrentTraceBackString() {
std::ostringstream sout;
sout << "\n\n--------------------------------------\n";
sout << "C++ Traceback (most recent call last):";
sout << "\n--------------------------------------\n";
#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
static constexpr int TRACE_STACK_LIMIT = 100;

void* call_stack[TRACE_STACK_LIMIT];
auto size = backtrace(call_stack, TRACE_STACK_LIMIT);
auto symbols = backtrace_symbols(call_stack, size);
Dl_info info;
int idx = 0;
int end_idx = 0;
for (int i = size - 1; i >= end_idx; --i) {
if (dladdr(call_stack[i], &info) && info.dli_sname) {
auto demangled = demangle(info.dli_sname);
std::string path(info.dli_fname);
// C++ traceback info are from core.so
if (path.substr(path.length() - 3).compare(".so") == 0) {
sout << idx++ << " " << demangled << "\n";
}
}
}
free(symbols);
#else
sout << "Not support stack backtrace yet.\n";
#endif
return sout.str();
}

static std::string GetTraceBackString(const std::string& what,
const char* file,
int line) {
return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line);
}

struct EnforceNotMet : public std::exception {
public:
EnforceNotMet(const std::string& str, const char* file, int line)
: err_str_(GetTraceBackString(str, file, line)) {}

const char* what() const noexcept override { return err_str_.c_str(); }

private:
std::string err_str_;
};

#define CINN_THROW(...) \
do { \
try { \
throw enforce::EnforceNotMet(__VA_ARGS__, __FILE__, __LINE__); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
} \
} while (0)
} // namespace enforce

/**
* This handler is dealing with the errors happen in in the current
* Scheduling.
*/
class IRScheduleErrorHandler {
class IRScheduleErrorHandler : public utils::ErrorHandler {
public:
/**
* \brief constructor
* \param err_msg the error message
*/
explicit IRScheduleErrorHandler(const std::string& err_msg,
explicit IRScheduleErrorHandler(const std::string& primitive,
const std::string& err_msg,
const ModuleExpr& module_expr)
: err_msg_(err_msg), module_expr_(module_expr) {}
: primitive_(primitive), err_msg_(err_msg), module_expr_(module_expr) {}

/**
* \brief Returns a short error message corresponding to the kGeneral error
Expand All @@ -151,25 +46,11 @@ class IRScheduleErrorHandler {
*/
std::string DetailedErrorMessage() const;

/**
* \brief Returns a detailed error message corresponding to the kDetailed
* error level.
*/
std::string FormatErrorMessage(
const std::string& primitive,
const ScheduleErrorMessageLevel& err_msg_level) const;

private:
ModuleExpr module_expr_;
std::string primitive_;
std::string err_msg_;
ModuleExpr module_expr_;
};

std::string NegativeFactorErrorMessage(const int64_t& factor,
const size_t& idx);

std::string InferFactorErrorMessage();

std::string FactorProductErrorMessage();

} // namespace ir
} // namespace cinn
27 changes: 22 additions & 5 deletions paddle/cinn/ir/schedule/ir_schedule_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ void ReplaceExpr(Expr* source,
std::vector<int> ValidateFactors(const std::vector<int>& factors,
int total_extent,
const ModuleExpr& module_expr) {
const std::string primitive = "split";
CHECK(!factors.empty())
<< "The factors param of Split should not be empty! Please check.";
bool has_minus_one = false;
Expand All @@ -230,11 +231,19 @@ std::vector<int> ValidateFactors(const std::vector<int>& factors,
for (auto& i : factors) {
idx++;
if (i == 0 || i < -1) {
throw IRScheduleErrorHandler(NegativeFactorErrorMessage(i, idx),
module_expr);
std::ostringstream os;
os << "The params in factors of Split should be positive. However, the "
"factor at position "
<< idx << " is " << i << std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
} else if (i == -1) {
if (has_minus_one) {
throw IRScheduleErrorHandler(InferFactorErrorMessage(), module_expr);
std::ostringstream os;
os << "The params in factors of Split should not be less than -1 or "
"have "
"more than one -1!"
<< std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
}
has_minus_one = true;
} else {
Expand All @@ -244,12 +253,20 @@ std::vector<int> ValidateFactors(const std::vector<int>& factors,
std::vector<int> validated_factors = factors;
if (!has_minus_one) {
if (product < total_extent) {
throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr);
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
}
return validated_factors;
} else {
if (product > total_extent) {
throw IRScheduleErrorHandler(FactorProductErrorMessage(), module_expr);
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal "
"to original loop's extent!"
<< std::endl;
throw IRScheduleErrorHandler(primitive, os.str(), module_expr);
}
int minus_one_candidate = static_cast<int>(
ceil(static_cast<double>(total_extent) / static_cast<double>(product)));
Expand Down
Loading