From f3fe997e9e2ea36f9a3387ae38e92fbbdf2f4194 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Thu, 15 Oct 2020 00:20:13 +0900 Subject: [PATCH] [refactor] Add format_error_message() method (#1955) --- taichi/program/program.cpp | 35 +++++++++-------------------------- taichi/util/str.cpp | 28 ++++++++++++++++++++++++++++ taichi/util/str.h | 4 ++++ 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index c6bc59bfb3131..9bb3777fb6be2 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -23,6 +23,7 @@ #include "taichi/ir/frontend_ir.h" #include "taichi/program/async_engine.h" #include "taichi/util/statistics.h" +#include "taichi/util/str.h" #if defined(TI_WITH_CC) #include "taichi/backends/cc/struct_cc.h" #include "taichi/backends/cc/cc_layout.h" @@ -428,7 +429,7 @@ void Program::check_runtime_error() { // memory), use the device context instead. tlctx = llvm_context_device.get(); } - auto runtime_jit_module = tlctx->runtime_jit_module; + auto *runtime_jit_module = tlctx->runtime_jit_module; runtime_jit_module->call("runtime_retrieve_and_reset_error_code", llvm_runtime); auto error_code = fetch_result(taichi_result_buffer_error_id); @@ -451,31 +452,13 @@ void Program::check_runtime_error() { } if (error_code == 1) { - std::string error_message_formatted; - int argument_id = 0; - for (int i = 0; i < (int)error_message_template.size(); i++) { - if (error_message_template[i] != '%') { - error_message_formatted += error_message_template[i]; - } else { - auto dtype = error_message_template[i + 1]; - runtime_jit_module->call( - "runtime_retrieve_error_message_argument", llvm_runtime, - argument_id); - auto argument = fetch_result(taichi_result_buffer_error_id); - if (dtype == 'd') { - error_message_formatted += fmt::format( - "{}", taichi_union_cast_with_different_sizes(argument)); - } else if (dtype == 'f') { - error_message_formatted += fmt::format( - "{}", - taichi_union_cast_with_different_sizes(argument)); - } else { - TI_ERROR("Data type identifier %{} is not supported", dtype); - } - argument_id += 1; - i++; // skip the dtype char - } - } + const auto error_message_formatted = format_error_message( + error_message_template, [runtime_jit_module, this](int argument_id) { + runtime_jit_module->call( + "runtime_retrieve_error_message_argument", llvm_runtime, + argument_id); + return fetch_result(taichi_result_buffer_error_id); + }); TI_ERROR("Assertion failure: {}", error_message_formatted); } else { TI_NOT_IMPLEMENTED diff --git a/taichi/util/str.cpp b/taichi/util/str.cpp index d2d9366415d8a..469bcf255dadd 100644 --- a/taichi/util/str.cpp +++ b/taichi/util/str.cpp @@ -2,6 +2,8 @@ #include +#include "taichi/inc/constants.h" + TLANG_NAMESPACE_BEGIN std::string c_quoted(std::string const &str) { @@ -32,4 +34,30 @@ std::string c_quoted(std::string const &str) { return ss.str(); } +std::string format_error_message(const std::string &error_message_template, + const std::function &fetcher) { + std::string error_message_formatted; + int argument_id = 0; + for (int i = 0; i < (int)error_message_template.size(); i++) { + if (error_message_template[i] != '%') { + error_message_formatted += error_message_template[i]; + } else { + const auto dtype = error_message_template[i + 1]; + const auto argument = fetcher(argument_id); + if (dtype == 'd') { + error_message_formatted += fmt::format( + "{}", taichi_union_cast_with_different_sizes(argument)); + } else if (dtype == 'f') { + error_message_formatted += fmt::format( + "{}", taichi_union_cast_with_different_sizes(argument)); + } else { + TI_ERROR("Data type identifier %{} is not supported", dtype); + } + argument_id += 1; + i++; // skip the dtype char + } + } + return error_message_formatted; +} + TLANG_NAMESPACE_END diff --git a/taichi/util/str.h b/taichi/util/str.h index 20f1cb7ec4668..a0021058f1f1e 100644 --- a/taichi/util/str.h +++ b/taichi/util/str.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include "taichi/lang_util.h" @@ -9,4 +10,7 @@ TLANG_NAMESPACE_BEGIN // Quote |str| with a pair of ". Escape special characters like \n, \t etc. std::string c_quoted(std::string const &str); +std::string format_error_message(const std::string &error_message_template, + const std::function &fetcher); + TLANG_NAMESPACE_END