diff --git a/.gitignore b/.gitignore
index 05079c76df..e82fdfd9b1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,4 +13,6 @@ experiments/
 py/build/
 py/tmp/
 py/.eggs
-.vscode/
\ No newline at end of file
+.vscode/
+.DS_Store
+._DS_Store
\ No newline at end of file
diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp
index 06f6317012..81a4ad6f8e 100644
--- a/core/conversion/conversionctx/ConversionCtx.cpp
+++ b/core/conversion/conversionctx/ConversionCtx.cpp
@@ -19,7 +19,8 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
     << "\n    Avg Timing Iterations: " << s.num_avg_timing_iters \
     << "\n    Max Workspace Size: " << s.workspace_size \
     << "\n    Device Type: " << s.device \
-    << "\n    Engine Capability: " << s.capability;
+    << "\n    Engine Capability: " << s.capability \
+    << "\n    Calibrator Created: " << (s.calibrator != nullptr);
     return os;
 }
@@ -36,13 +37,17 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
 
     switch(settings.op_precision) {
     case nvinfer1::DataType::kHALF:
+        TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
         cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
         input_type = nvinfer1::DataType::kHALF;
         break;
-    // case nvinfer1::DataType::kINT8:
-    //     cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
-    //     input_type = nvinfer1::DataType::kFLOAT;
-    //     break;
+    case nvinfer1::DataType::kINT8:
+        TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
+        cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
+        input_type = nvinfer1::DataType::kINT8;
+        // If the calibrator is nullptr then TRT will use default quantization
+        cfg->setInt8Calibrator(settings.calibrator);
+        break;
     case nvinfer1::DataType::kFLOAT:
     default:
         input_type = nvinfer1::DataType::kFLOAT;
diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h
index 06f3755490..d8adef5ff0 100644
--- a/core/conversion/conversionctx/ConversionCtx.h
+++ b/core/conversion/conversionctx/ConversionCtx.h
@@ -24,6 +24,7 @@ struct BuilderSettings {
     bool allow_gpu_fallback = true;
     nvinfer1::DeviceType device = nvinfer1::DeviceType::kGPU;
     nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
+    nvinfer1::IInt8Calibrator* calibrator = nullptr;
     uint64_t num_min_timing_iters = 2;
     uint64_t num_avg_timing_iters = 1;
     uint64_t workspace_size = 0;
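For context, the new `calibrator` field is the only thing a caller has to set beyond `op_precision` to reach the new `kINT8` branch. A minimal sketch of how a caller might exercise it, assuming the `trtorch::core::conversion` namespace used elsewhere in `core/conversion` and a calibrator pointer obtained from the factory introduced below:

``` c++
#include "core/conversion/conversionctx/ConversionCtx.h"

// Hypothetical helper, not part of this patch: build a ConversionCtx that
// takes the new kINT8 case above. A nullptr calibrator falls back to TRT's
// default quantization, per the comment in the switch.
trtorch::core::conversion::ConversionCtx* make_int8_ctx(nvinfer1::IInt8Calibrator* calibrator) {
    trtorch::core::conversion::BuilderSettings settings;
    settings.op_precision = nvinfer1::DataType::kINT8; // selects the new case
    settings.calibrator = calibrator;                  // nullptr => default quantization
    return new trtorch::core::conversion::ConversionCtx(settings);
}
```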
diff --git a/core/quantization/BUILD b/core/quantization/BUILD
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/core/quantization/TRTEntropyCalibrator.cpp b/core/quantization/TRTEntropyCalibrator.cpp
new file mode 100644
index 0000000000..ebc9188da3
--- /dev/null
+++ b/core/quantization/TRTEntropyCalibrator.cpp
@@ -0,0 +1,64 @@
+#include <fstream>
+#include <iterator>
+
+#include "core/util/prelude.h"
+#include "core/quantization/quantization.h"
+
+namespace trtorch {
+namespace core {
+namespace quantization {
+
+Int8CalibratorImpl::Int8CalibratorImpl(QuantizationSettings settings)
+    : dataset_(std::move(settings.calibration_dataset)),
+      cache_file_path_(settings.calibration_cache_file),
+      use_cache_(settings.use_cache) {
+    buffers_.reserve(dataset_.size());
+}
+
+int Int8CalibratorImpl::GetBatchSize() const {
+    // TODO: report the real batch size once batching is implemented;
+    // a single-image batch is assumed for now
+    return 1;
+}
+
+bool Int8CalibratorImpl::GetBatch(void* bindings[], const char* names[], int num_bindings) {
+    if (!is_next_batch) {
+        return false;
+    }
+
+    // TensorRT has consumed the previous batch by the time it asks for the
+    // next one, so the tensors backing it can be released here
+    buffers_.clear();
+    for (int i = 0; i < num_bindings; i++) {
+        auto batch = next_binding_batch(names[i]);
+        batch = batch.to(at::kCUDA).contiguous();
+        // keep the tensor alive until TensorRT has read the batch
+        buffers_.push_back(batch);
+        bindings[i] = buffers_.back().data_ptr();
+    }
+    return true;
+}
+
+const void* Int8CalibratorImpl::ReadCalibrationCache(size_t& length) {
+    cache_.clear();
+    std::ifstream cache_file(cache_file_path_, std::ios::binary);
+    cache_file >> std::noskipws;
+    if (use_cache_ && cache_file.good()) {
+        std::copy(std::istream_iterator<char>(cache_file),
+                  std::istream_iterator<char>(),
+                  std::back_inserter(cache_));
+    }
+    cache_size_ = cache_.size();
+    length = cache_size_;
+    return cache_size_ ? cache_.data() : nullptr;
+}
+
+void Int8CalibratorImpl::WriteCalibrationCache(const void* cache, size_t length) {
+    std::ofstream cache_file(cache_file_path_, std::ios::binary);
+    cache_file.write(reinterpret_cast<const char*>(cache), length);
+}
+
+// IInt8Calibrator is abstract, so the concrete calibrator is heap-allocated
+// and returned by pointer; the caller owns it
+nvinfer1::IInt8Calibrator* create_int8_calibrator(QuantizationSettings settings) {
+    auto kind = settings.calibrator_type;
+    auto calibrator_impl = Int8CalibratorImpl(std::move(settings));
+    switch(kind) {
+    case CalibratorKind::kMinMax:
+        return new TRTInt8MinMaxCalibrator(std::move(calibrator_impl));
+    case CalibratorKind::kEntropy:
+    default:
+        return new TRTInt8EntropyCalibrator(std::move(calibrator_impl));
+    }
+}
+
+} // quantization
+} // core
+} // trtorch
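One TensorRT detail worth noting: the builder calls `readCalibrationCache` first, and a non-null return skips calibration entirely, which is why `use_cache_` gates the file read. A sketch of how the factory might be used end to end; the binding name, cache path, and tensor shape here are illustrative only, not part of this patch:

``` c++
#include "core/quantization/quantization.h"

// Hypothetical usage of create_int8_calibrator: one dataset entry per
// engine input binding name ("input_0" is an assumption).
nvinfer1::IInt8Calibrator* build_mnist_calibrator() {
    trtorch::core::quantization::QuantizationSettings settings;
    settings.calibrator_type = trtorch::core::quantization::CalibratorKind::kEntropy;
    settings.calibration_cache_file = "/tmp/mnist_calibration.cache"; // reused on later runs
    settings.use_cache = true;
    settings.calibration_dataset["input_0"] = at::randn({32, 1, 28, 28});
    return trtorch::core::quantization::create_int8_calibrator(std::move(settings));
}
```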
diff --git a/core/quantization/quantization.h b/core/quantization/quantization.h
new file mode 100644
index 0000000000..5a6c150923
--- /dev/null
+++ b/core/quantization/quantization.h
@@ -0,0 +1,69 @@
+#pragma once
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "ATen/ATen.h"
+#include "NvInfer.h"
+
+namespace trtorch {
+namespace core {
+namespace quantization {
+
+enum class CalibratorKind {
+    kEntropy,
+    kMinMax,
+};
+
+// TODO: in the converter, map inputs given as a std::vector (in input order)
+// into this name -> tensor map
+struct QuantizationSettings {
+    CalibratorKind calibrator_type = CalibratorKind::kEntropy;
+    std::string calibration_cache_file = "";
+    bool use_cache = false;
+    std::unordered_map<std::string, at::Tensor> calibration_dataset;
+};
+
+class CalibrationBatchStream {
+
+};
+
+class Int8CalibratorImpl {
+public:
+    Int8CalibratorImpl(QuantizationSettings settings);
+    int GetBatchSize() const;
+    bool GetBatch(void* bindings[], const char* names[], int num_bindings);
+    const void* ReadCalibrationCache(size_t& length);
+    void WriteCalibrationCache(const void* cache, size_t length);
+private:
+    std::unordered_map<std::string, at::Tensor> dataset_;
+    std::string cache_file_path_;
+    std::vector<char> cache_;
+    // keeps the device tensors handed out by GetBatch alive while TensorRT
+    // reads them
+    std::vector<at::Tensor> buffers_;
+    bool use_cache_;
+    size_t cache_size_ = 0;
+    // batching helpers, to be implemented: whether another batch is
+    // available, and the next batch for a given input binding name
+    bool is_next_batch = false;
+    at::Tensor next_binding_batch(const char* name);
+};
+
+class TRTInt8EntropyCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
+public:
+    TRTInt8EntropyCalibrator(Int8CalibratorImpl impl) : impl_(std::move(impl)) {}
+    int getBatchSize() const override { return impl_.GetBatchSize(); }
+    bool getBatch(void* bindings[], const char* names[], int nbBindings) override { return impl_.GetBatch(bindings, names, nbBindings); }
+    const void* readCalibrationCache(size_t& length) override { return impl_.ReadCalibrationCache(length); }
+    void writeCalibrationCache(const void* cache, size_t length) override { impl_.WriteCalibrationCache(cache, length); }
+private:
+    Int8CalibratorImpl impl_;
+};
+
+class TRTInt8MinMaxCalibrator : public nvinfer1::IInt8MinMaxCalibrator {
+public:
+    TRTInt8MinMaxCalibrator(Int8CalibratorImpl impl) : impl_(std::move(impl)) {}
+    int getBatchSize() const override { return impl_.GetBatchSize(); }
+    bool getBatch(void* bindings[], const char* names[], int nbBindings) override { return impl_.GetBatch(bindings, names, nbBindings); }
+    const void* readCalibrationCache(size_t& length) override { return impl_.ReadCalibrationCache(length); }
+    void writeCalibrationCache(const void* cache, size_t length) override { impl_.WriteCalibrationCache(cache, length); }
+private:
+    Int8CalibratorImpl impl_;
+};
+
+// returns a heap-allocated calibrator (IInt8Calibrator is abstract); the
+// caller owns the pointer
+nvinfer1::IInt8Calibrator* create_int8_calibrator(QuantizationSettings settings);
+
+} // quantization
+} // core
+} // trtorch
\ No newline at end of file
diff --git a/cpp/ptq/BUILD b/cpp/ptq/BUILD
new file mode 100644
index 0000000000..f17db99cc0
--- /dev/null
+++ b/cpp/ptq/BUILD
@@ -0,0 +1,12 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_binary(
+    name = "ptq",
+    srcs = [
+        "main.cpp"
+    ],
+    deps = [
+        "@libtorch//:libtorch",
+        "//cpp/api:trtorch"
+    ],
+)
diff --git a/cpp/ptq/README.md b/cpp/ptq/README.md
new file mode 100644
index 0000000000..ffec48eaaf
--- /dev/null
+++ b/cpp/ptq/README.md
@@ -0,0 +1,21 @@
+# ptq
+
+This is a short example application that shows how to use TRTorch to perform post-training quantization for a module.
+
+## Compilation
+
+``` shell
+bazel build //cpp/ptq --cxxopt="-DNDEBUG"
+```
+
+If you want insight into what is going on under the hood or need debug symbols
+
+``` shell
+bazel build //cpp/ptq --compilation_mode=dbg
+```
+
+## Usage
+
+``` shell
+ptq <path-to-module> <path-to-mnist-data>
+```
\ No newline at end of file
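The example application below only iterates the MNIST dataloader and prints batch sizes so far. Wiring it into the calibrator would look roughly like the following sketch; the binding name "input_0" and the concatenate-everything strategy are assumptions about where the example is headed, not code in this patch:

``` c++
// Hypothetical glue between the dataloader built in main.cpp and
// QuantizationSettings. With the Stack<> transform applied, each `batch`
// is a torch::data::Example whose .data member is a [N, 1, 28, 28] tensor.
std::vector<at::Tensor> batches;
for (auto batch : *calibration_dataloader) {
    batches.push_back(batch.data);
}
trtorch::core::quantization::QuantizationSettings settings;
settings.calibration_dataset["input_0"] = at::cat(batches, /*dim=*/0);
auto calibrator = trtorch::core::quantization::create_int8_calibrator(std::move(settings));
```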
diff --git a/cpp/ptq/main.cpp b/cpp/ptq/main.cpp
new file mode 100644
index 0000000000..f7fe857c99
--- /dev/null
+++ b/cpp/ptq/main.cpp
@@ -0,0 +1,36 @@
+#include "torch/script.h"
+#include "torch/csrc/api/include/torch/data/datasets/mnist.h"
+#include "trtorch/trtorch.h"
+
+#include <iostream>
+#include <memory>
+#include <string>
+
+int main(int argc, const char* argv[]) {
+    if (argc < 3) {
+        std::cerr << "usage: ptq <path-to-module> <path-to-mnist-data>\n";
+        return -1;
+    }
+
+    torch::jit::script::Module mod;
+    try {
+        // Deserialize the ScriptModule from a file using torch::jit::load().
+        mod = torch::jit::load(argv[1]);
+    }
+    catch (const c10::Error& e) {
+        std::cerr << "error loading the model\n";
+        return -1;
+    }
+
+    const std::string data_dir = std::string(argv[2]);
+    auto calibration_dataset = torch::data::datasets::MNIST(data_dir, torch::data::datasets::MNIST::Mode::kTest)
+        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
+        .map(torch::data::transforms::Stack<>());
+    auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset),
+        torch::data::DataLoaderOptions().batch_size(32).workers(1));
+
+    for (auto batch : *calibration_dataloader) {
+        std::cout << batch.data.sizes() << std::endl;
+    }
+}
diff --git a/cpp/trtorchexec/main.cpp b/cpp/trtorchexec/main.cpp
index dec456e913..f3909fef27 100644
--- a/cpp/trtorchexec/main.cpp
+++ b/cpp/trtorchexec/main.cpp
@@ -12,7 +12,7 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
         maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
     }
     std::cout << "Max Difference: " << diff.abs().max().item<float>() << std::endl;
-    return diff.abs().max().item<float>() <= 2e-6 * maxValue;
+    return diff.abs().max().item<float>() <= 2e-5 * maxValue;
 }
 
 bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
@@ -25,8 +25,8 @@ int main(int argc, const char* argv[]) {
               << "    trtorchexec <path-to-exported-script-module> <input-size>...\n";
     return -1;
   }
-  
-  
+
+
   torch::jit::script::Module mod;
   try {
     // Deserialize the ScriptModule from a file using torch::jit::load().
@@ -38,7 +38,7 @@ int main(int argc, const char* argv[]) {
   }
 
   mod.to(at::kCUDA);
-  
+
   std::vector<std::vector<int64_t>> dims;
   for (int i = 2; i < argc; i++) {
     auto arg = std::string(argv[i]);
@@ -74,7 +74,7 @@ int main(int argc, const char* argv[]) {
   torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
   std::vector<at::Tensor> jit_results;
   jit_results.push_back(jit_results_ivalues.toTensor());
-  
+
   auto trt_mod = trtorch::CompileGraph(mod, dims);
   torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
   std::vector<at::Tensor> trt_results;
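The tolerance change in `checkRtol` is relative: the test now passes when max|diff| <= 2e-5 * max|input|, an order of magnitude looser than before, which accommodates the larger numerical drift of reduced-precision engines. A quick usage sketch of `almostEqual` under the new threshold (a standalone snippet; `almostEqual` is the function defined in main.cpp above):

``` c++
#include <cassert>
#include "ATen/ATen.h"

// For inputs of magnitude ~1e2, a 2e-5 relative tolerance admits absolute
// differences up to ~2e-3, so a 1e-4 perturbation should compare equal.
void tolerance_smoke_test() {
    at::Tensor a = at::randn({10}) * 100;
    at::Tensor b = a + 1e-4; // well inside 2e-5 * max|a| for a typical draw
    assert(almostEqual(a, b));
}
```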