-
Notifications
You must be signed in to change notification settings - Fork 355
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(//core/quantization): skeleton of INT8 PTQ calibrator
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
- Loading branch information
1 parent
aef6003
commit dd443a6
Showing
10 changed files
with
221 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,6 @@ experiments/ | |
py/build/ | ||
py/tmp/ | ||
py/.eggs | ||
.vscode/ | ||
.vscode/ | ||
.DS_Store | ||
._DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
#include "core/util/prelude.h" | ||
#include "core/quantization/quantization.h" | ||
|
||
namespace trtorch { | ||
namespace core { | ||
namespace quantization { | ||
|
||
Int8CalibratorImpl::Int8CalibratorImpl(QuantizationSettings&& settings) | ||
: dataset_(std::move(settings.calibration_dataset), | ||
cache_file_path_(settings.calibration_cache_file), | ||
use_cache_(settings.use_cache) { | ||
buffers_.reserve(dataset_.size); | ||
|
||
} | ||
|
||
int Int8CalibratorImpl::GetBatchSize() const { | ||
|
||
} | ||
|
||
bool Int8CalibratorImpl::GetBatch(void* bindings[], const char* names[], int num_bindings) { | ||
if (!is_next_batch) { | ||
return false; | ||
} | ||
|
||
for (size_t i = 0; i < num_bindings; i++) { | ||
auto batch = next_binding_batch(names[i]); | ||
batch = batch.to(at::kCUDA).contiguous(); | ||
bindings[i] = batch.data_ptr(); | ||
} | ||
return true; | ||
} | ||
|
||
const void* Int8CalibratorImpl::ReadCalibrationCache(size_t& length) { | ||
cache_.clear(); | ||
std::ifstream cache_file(cache_file_path_, std::ios::binary); | ||
cache_file >> std::noskipws; | ||
if (use_cache && cache_file.good()) { | ||
std::copy(std::istream_iterator<char>(input), | ||
std::istream_iterator<char>(), | ||
std::back_inserter(cache_)); | ||
} | ||
cache_size_ = cache_.size(); | ||
return cache_size ? cache_.data() : nullptr; | ||
} | ||
|
||
void Int8CalibratorImpl::WriteCalibrationCache(const void* cache, size_t length) { | ||
std::ofstream cache_file(cache_file_path_, std::ios::binary); | ||
cache_file.write(reinterpret_cast<const char*>(cache_), cache_size_); | ||
} | ||
|
||
nvinfer1::IInt8Calibrator create_int8_calibrator(QuantizationSettings settings) { | ||
auto calibrator_impl = Int8CalibratorImpl(settings); | ||
switch(settings.calibrator_type) { | ||
case CalibratorKind::kMinMax: | ||
return TRTInt8MinMaxCalibrator(std::move(calibrator_impl)); | ||
case CalibratorKind::kEntropy: | ||
default: | ||
return TRTInt8EntropyCalibrator(std::move(calibrator_impl)); | ||
} | ||
} | ||
|
||
} // quantization | ||
} // core | ||
} // trtorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#pragma once | ||
#include "ATen/tensor.h" | ||
#include "NvInfer.h" | ||
|
||
namespace trtorch { | ||
namespace core { | ||
namespace quantization { | ||
|
||
enum class CalibratorKind { | ||
kEntropy, | ||
kMinMax, | ||
} | ||
|
||
in conveter or whatever | ||
in order given std::vector<at::Tensor> -> map<input_name, at::Tensor> | ||
|
||
struct QuantizationSettings { | ||
CalibratorKind calibrator_type = CalibratorKind::kEntropy; | ||
const std::string& calibration_cache_file = ""; | ||
bool use_cache = false; | ||
std::unordered_map<std::string, at::Tensor> calibration_dataset; | ||
}; | ||
|
||
class CalibrationBatchStream { | ||
|
||
}; | ||
|
||
class Int8CalibratorImpl { | ||
public: | ||
TRTInt8CalibratorImpl(QuantizationSettings& settings); | ||
int GetBatchSize() const; | ||
bool GetBatch(void* bindings[], const char* names[], int num_bindings); | ||
const void* ReadCalibrationCache(size_t& length); | ||
void WriteCalibrationCache(const void* cache, size_t length); | ||
private: | ||
std::unordered_map<std::string, at::Tensor> dataset_; | ||
const std::string& cache_file_path_; | ||
std::vector<char> cache_; | ||
bool use_cache_; | ||
size_t cache_size_ = 0; | ||
}; | ||
|
||
class TRTInt8EntropyCalibrator : nvinfer1::IInt8EntropyCalibrator2 { | ||
public: | ||
TRTInt8EntropyCalibrator(Int8CalibratorImpl impl) : impl_(impl) {} | ||
int getBatchSize() const override {return impl_.GetBatchSize();} | ||
bool getBatch(void* bindings[], const char* names[], int nbBindings) override {return impl_.GetBatch(bindings, names, nbBindings)}; | ||
const void* readCalibrationCache(size_t& length) override {return impl_.ReadCalibrationCache(size_t& length)}; | ||
void writeCalibrationCache(const void* cache, size_t length) override {impl_.WriteCalibrationCache(const void* cache, size_t length)}; | ||
private: | ||
Int8CalibratorImpl impl_; | ||
}; | ||
|
||
class TRTInt8MinMaxCalibrator : nvinfer1::IInt8MinMaxCalibrator { | ||
public: | ||
TRTInt8EntropyCalibrator(Int8CalibratorImpl impl) : impl_(impl) {} | ||
int getBatchSize() const override {return impl_.GetBatchSize();} | ||
bool getBatch(void* bindings[], const char* names[], int nbBindings) override {return impl_.GetBatch(bindings, names, nbBindings)}; | ||
const void* readCalibrationCache(size_t& length) override {return impl_.ReadCalibrationCache(size_t& length)}; | ||
void writeCalibrationCache(const void* cache, size_t length) override {impl_.WriteCalibrationCache(const void* cache, size_t length)}; | ||
private: | ||
Int8CalibratorImpl impl_; | ||
}; | ||
|
||
nvinfer1::IInt8Calibrator create_int8_calibrator(QuantizationSettings settings); | ||
|
||
} // quantization | ||
} // core | ||
} // trtorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
package(default_visibility = ["//visibility:public"]) | ||
|
||
cc_binary( | ||
name = "ptq", | ||
srcs = [ | ||
"main.cpp" | ||
], | ||
deps = [ | ||
"@libtorch//:libtorch", | ||
"//cpp/api:trtorch" | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# ptq | ||
|
||
This is a short example application that shows how to use TRTorch to perform post-training quantization for a module. | ||
|
||
## Compilation | ||
|
||
``` shell | ||
bazel build //cpp/ptq --cxxopt="-DNDEBUG" | ||
``` | ||
|
||
If you want insight into what is going under the hood or need debug symbols | ||
|
||
``` shell | ||
bazel build //cpp/ptq --compilation_mode=dbg | ||
``` | ||
|
||
## Usage | ||
|
||
``` shell | ||
ptq | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#include "torch/script.h" | ||
#include "torch/csrc/api/include/torch/data/datasets/mnist.h" | ||
#include "trtorch/trtorch.h" | ||
|
||
#include <iostream> | ||
#include <sstream> | ||
#include <memory> | ||
|
||
int main(int argc, const char* argv[]) { | ||
if (argc < 3) { | ||
std::cerr << "usage: ptq <path-to-module> <path-to-mnist>\n"; | ||
return -1; | ||
} | ||
|
||
torch::jit::script::Module mod; | ||
try { | ||
// Deserialize the ScriptModule from a file using torch::jit::load(). | ||
mod = torch::jit::load(argv[1]); | ||
} | ||
catch (const c10::Error& e) { | ||
std::cerr << "error loading the model\n"; | ||
return -1; | ||
} | ||
|
||
const std::string data_dir = std::string(argv[2]); | ||
auto calibration_dataset = torch::data::datasets::MNIST(data_dir, torch::data::datasets::MNIST::Mode::kTest) | ||
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) | ||
.map(torch::data::transforms::Stack<>()); | ||
auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset), torch::data::DataLoaderOptions() | ||
.batch_size(32) | ||
.workers(1)) | ||
|
||
for (auto batch : batched_calibration_dataset) { | ||
std::cout << batch.data().sizes() << std::endl; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters