diff --git a/cpp/ptq/main.cpp b/cpp/ptq/main.cpp index 5381bc408a..ed40398826 100644 --- a/cpp/ptq/main.cpp +++ b/cpp/ptq/main.cpp @@ -35,7 +35,6 @@ struct Resize : public torch::data::transforms::TensorTransform { torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::Module& mod) { auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest) .use_subset(320) - .map(Resize({300, 300})) .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010})) .map(torch::data::transforms::Stack<>()); @@ -48,7 +47,7 @@ torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::M auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true); - std::vector> input_shape = {{32, 3, 300, 300}}; + std::vector> input_shape = {{32, 3, 32, 32}}; /// Configure settings for compilation auto extra_info = trtorch::ExtraInfo({input_shape}); /// Set operating precision to INT8 @@ -99,7 +98,6 @@ int main(int argc, const char* argv[]) { /// Dataloader moved into calibrator so need another for inference auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest) - .map(Resize({300, 300})) .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010})) .map(torch::data::transforms::Stack<>()); @@ -131,7 +129,7 @@ int main(int argc, const char* argv[]) { if (images.sizes()[0] < 32) { /// To handle smaller batches util Optimization profiles work with Int8 auto diff = 32 - images.sizes()[0]; - auto img_padding = torch::zeros({diff, 3, 300, 300}, {torch::kCUDA}); + auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA}); auto target_padding = torch::zeros({diff}, {torch::kCUDA}); images = torch::cat({images, img_padding}, 0); targets = torch::cat({targets, target_padding}, 0); @@ -152,7 +150,7 @@ int main(int argc, const char* argv[]) { std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl; /// Time execution in JIT-FP32 and TRT-INT8 - std::vector> dims = {{32, 3, 300, 300}}; + std::vector> dims = {{32, 3, 32, 32}}; auto jit_runtimes = benchmark_module(mod, dims[0]); print_avg_std_dev("JIT model FP32", jit_runtimes, dims[0][0]);