refactor(//cpp/api/ptq): Move from direct use of dataloader to a
buffered version to improve stability

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed May 28, 2020
1 parent e6f598f commit 4741246
Showing 3 changed files with 43 additions and 28 deletions.
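
The change replaces the calibrator's live torch::data iterator, which previously had to be created lazily inside getBatch() because a dataloader can only have one live iterator, with a std::vector of input tensors filled once in the constructor. Below is a minimal sketch of that buffering pattern on its own, outside of TensorRT and TRTorch; the BufferedBatches class is an illustrative stand-in, not code from this commit, and it assumes the dataloader yields collated examples exposing a .data tensor (e.g. a dataset mapped with torch::data::transforms::Stack).

#include <vector>

#include "torch/torch.h"

// Illustrative stand-in for the new buffering approach: drain the dataloader
// once, keep only the input tensors, and traverse them with a plain vector
// iterator that can be rewound when the data is exhausted.
template <typename DataLoader>
class BufferedBatches {
 public:
  explicit BufferedBatches(DataLoader& dataloader) {
    for (auto batch : dataloader) {          // single pass over the dataloader's only live iterator
      batched_data_.push_back(batch.data);   // keep just the input tensor of each collated example
    }
    it_ = batched_data_.begin();
  }

  // Returns false once all buffered batches have been handed out, then rewinds
  // so the same object can serve another calibration pass.
  bool next(torch::Tensor& out) {
    if (it_ != batched_data_.end()) {
      out = *it_;
      ++it_;
      return true;
    }
    it_ = batched_data_.begin();
    return false;
  }

 private:
  std::vector<torch::Tensor> batched_data_;
  std::vector<torch::Tensor>::iterator it_;
};

The trade-off is that the whole calibration set is held in memory for the calibrator's lifetime, in exchange for an iterator that is trivial to restart and no longer tied to the dataloader's single live begin().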
1 change: 1 addition & 0 deletions cpp/api/BUILD
@@ -12,6 +12,7 @@ cc_library(
"src/extra_info.cpp",
"src/logging.cpp",
"src/trtorch.cpp",
"src/ptq.cpp"
],
deps = [
"//core",
54 changes: 26 additions & 28 deletions cpp/api/include/trtorch/ptq.h
@@ -6,16 +6,21 @@
#include <iostream>
#include <sstream>

#include "trtorch/logging.h"

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace nvinfer1 {
class IInt8Calibrator;
class IInt8EntropyCalibrator2;
}

namespace torch {
namespace data {
template<typename Example>
class Iterator;
class Tensor;
}

namespace trtorch {
namespace ptq {
bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);
}
}
#endif //DOXYGEN_SHOULD_SKIP_THIS
@@ -45,7 +50,12 @@ class Int8Calibrator : Algorithm {
* @param use_cache : bool - Whether to use the cache (if it exists)
*/
Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
: dataloader_(dataloader.get()), it_(dataloader_->end()), cache_file_path_(cache_file_path), use_cache_(use_cache) {}
: dataloader_(dataloader.get()), cache_file_path_(cache_file_path), use_cache_(use_cache) {
for (auto batch : *dataloader_) {
batched_data_.push_back(batch.data);
}
it_ = batched_data_.begin();
}

/**
* @brief Get the Batch Size for the next batch (always 1 due to issues with TRT and explicit batch)
@@ -70,26 +80,15 @@
* @return false - There is not a new batch for the calibrator to consume
*/
bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
// HACK: doesn't seem like the first try in the initializer list works
if (! it_created_) {
it_ = dataloader_->begin();
it_created_ = true;
}

if (it_ == dataloader_->end()) {
if (it_ != batched_data_.end()) {
auto status = get_batch_impl(bindings, names, nbBindings, *it_);
it_ = ++it_;
return status;
} else {
// Reset iterator in case the calibrator is going to be used again
it_ = batched_data_.begin();
return false;
}

auto batch = *it_;

for (int i = 0; i < nbBindings; i++) {
auto data = batch.data;
data = data.to(at::kCUDA).contiguous();
bindings[i] = data.data_ptr();
}

it_ = ++it_;
return true;
}

/**
@@ -151,8 +150,6 @@
private:
/// Pointer to the dataloader
DataLoader* dataloader_;
/// Iterator used to traverse the dataloader
torch::data::Iterator<Batch> it_;
/// Path to cache file
const std::string& cache_file_path_;
/// Size of cache
@@ -161,10 +158,11 @@
bool use_cache_;
/// Cache data
std::vector<char> cache_;
/// If the iterator has been created, DataLoaders can only have 1 live iterator,
/// due to some issues this cannot be created at construction, so it is set in the first
/// batch, controlled by this flag
bool it_created_ = false;
/// Batched Data
std::vector<torch::Tensor> batched_data_;
/// Iterator to move through dataset
std::vector<torch::Tensor>::iterator it_;

};

/**
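
Since the dataloader is now drained inside the constructor shown above, calibration batches are materialized when the calibrator is created rather than lazily while the engine is built. A hedged usage sketch follows; it assumes the make_int8_calibrator factory defined further down in this header keeps the constructor's (dataloader, cache file path, use_cache) parameter order, and the MNIST pipeline, batch size, and file paths are placeholders.

#include <string>

#include "torch/torch.h"
#include "trtorch/ptq.h"

int main() {
  // Placeholder calibration pipeline; any torch::data dataloader whose collated
  // examples expose a .data tensor would do.
  auto calibration_dataset =
      torch::data::datasets::MNIST("/tmp/mnist", torch::data::datasets::MNIST::Mode::kTest)
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());
  auto calibration_dataloader = torch::data::make_data_loader(
      std::move(calibration_dataset),
      torch::data::DataLoaderOptions().batch_size(32).workers(2));

  const std::string calibration_cache_file = "/tmp/int8_cache.calib";  // placeholder path

  // Every calibration batch is read out of the dataloader here, inside the
  // calibrator's constructor, and buffered as torch::Tensors; the dataloader is
  // not traversed again while TensorRT calls getBatch().
  auto calibrator = trtorch::ptq::make_int8_calibrator(
      std::move(calibration_dataloader), calibration_cache_file, /*use_cache=*/true);

  (void)calibrator;  // would normally be handed to the INT8 compile settings
  return 0;
}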
16 changes: 16 additions & 0 deletions cpp/api/src/ptq.cpp
@@ -0,0 +1,16 @@
#include "torch/torch.h"
#include "trtorch/ptq.h"

namespace trtorch {
namespace ptq {

bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data) {
for (int i = 0; i < nbBindings; i++) {
data = data.to(at::kCUDA).contiguous();
bindings[i] = data.data_ptr();
}
return true;
}

} // namespace ptq
} // namespace trtorch
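
The helper moved into ptq.cpp binds every entry of bindings to the same CUDA-resident, contiguous copy of the batch tensor, so it effectively assumes a single input binding. A small sketch of calling it directly, just to show what it leaves behind; the tensor shape and binding name are placeholders.

#include <iostream>

#include "torch/torch.h"
#include "trtorch/ptq.h"

int main() {
  if (!torch::cuda::is_available()) {
    return 0;  // get_batch_impl moves the batch to the GPU, so CUDA is required
  }

  auto batch = torch::rand({32, 3, 224, 224});  // placeholder calibration batch
  const char* names[] = {"input_0"};            // placeholder binding name
  void* bindings[1] = {nullptr};

  trtorch::ptq::get_batch_impl(bindings, names, 1, batch);

  // `batch` has been replaced in place by its CUDA contiguous copy, and
  // bindings[0] now points at that tensor's device memory.
  std::cout << std::boolalpha << (bindings[0] == batch.data_ptr()) << std::endl;
  return 0;
}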
