From 3afd2091c9f495af1b228726e323e7d2e6a0b561 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 24 Apr 2020 12:42:46 -0700 Subject: [PATCH] fix(//core/conversion): Check for calibrator before setting int8 mode Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/conversionctx/ConversionCtx.cpp | 2 +- core/util/logging/TRTorchLogger.cpp | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 2d2e321a83..a3a5ddfc01 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -52,7 +52,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings) TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does support INT8"); cfg->setFlag(nvinfer1::BuilderFlag::kINT8); input_type = nvinfer1::DataType::kFLOAT; - // If the calibrator is nullptr then TRT will use default quantization + TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator"); cfg->setInt8Calibrator(settings.calibrator); break; case nvinfer1::DataType::kFLOAT: diff --git a/core/util/logging/TRTorchLogger.cpp b/core/util/logging/TRTorchLogger.cpp index a90ebed82c..d3968c9ee3 100644 --- a/core/util/logging/TRTorchLogger.cpp +++ b/core/util/logging/TRTorchLogger.cpp @@ -15,7 +15,7 @@ namespace trt = nvinfer1; namespace util { namespace logging { - + TRTorchLogger::TRTorchLogger(std::string prefix, Severity severity, bool color) : prefix_(prefix), reportable_severity_(severity), color_(color) {} @@ -32,7 +32,7 @@ void TRTorchLogger::log(Severity severity, const char* msg) { if (severity > reportable_severity_) { return; } - + if (color_) { switch (severity) { case Severity::kINTERNAL_ERROR: std::cerr << TERM_RED; break; @@ -41,9 +41,9 @@ void TRTorchLogger::log(Severity severity, const char* msg) { case Severity::kINFO: std::cerr << TERM_GREEN; break; case Severity::kVERBOSE: std::cerr << TERM_MAGENTA; break; default: break; - } + } } - + switch (severity) { case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break; case Severity::kERROR: std::cerr << "ERROR: "; break; @@ -52,11 +52,11 @@ void TRTorchLogger::log(Severity severity, const char* msg) { case Severity::kVERBOSE: std::cerr << "DEBUG: "; break; default: std::cerr << "UNKNOWN: "; break; } - + if (color_) { std::cerr << TERM_NORMAL; } - + std::cerr << prefix_ << msg << std::endl; } @@ -92,7 +92,7 @@ bool TRTorchLogger::get_is_colored_output_on() { return color_; } - + namespace { TRTorchLogger& get_global_logger() { @@ -104,7 +104,7 @@ TRTorchLogger& get_global_logger() { static TRTorchLogger global_logger("[TRTorch] - ", LogLevel::kERROR, false); - #endif + #endif return global_logger; }