diff --git a/.gitignore b/.gitignore index 5c2e84e175..37dd6a63ea 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,8 @@ py/.eggs ._DS_Store *.pth *.pyc -cpp/ptq/training/vgg16/data/ \ No newline at end of file +cpp/ptq/training/vgg16/data/* +*.bin +cpp/ptq/datasets/data/ +._.DS_Store +*.tar.gz diff --git a/core/BUILD b/core/BUILD index cd3fe3f21a..be3937aa64 100644 --- a/core/BUILD +++ b/core/BUILD @@ -16,7 +16,7 @@ cc_library( "@libtorch//:libtorch", "@tensorrt//:nvinfer" ], - alwayslink=True, + alwayslink=True, ) diff --git a/core/compiler.cpp b/core/compiler.cpp index 33e2f04bff..9e3a69033b 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -24,24 +24,24 @@ namespace trtorch { namespace core { -c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr& g) { +c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr& g) { std::vector args; for (auto in : g->inputs()) { args.push_back(c10::Argument(in->debugName(), in->type())); } - + std::vector returns; for (auto out : g->outputs()) { returns.push_back(c10::Argument(out->debugName(), out->type())); } - + return c10::FunctionSchema(method_name, method_name, args, returns); } void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr& g, std::string& serialized_engine) { - execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine); + execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine); auto schema = execution::GetEngineFunctionSchema(uid); auto num_io = execution::GetEngineIO(uid); @@ -53,14 +53,14 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptrsetType(c10::TensorType::get()); graph_inputs.push_back(in_val); } - + auto engine_node = g->create(c10::Symbol::fromQualString(schema.name()), torch::jit::ArrayRef(graph_inputs), num_io.second); g->block()->appendNode(engine_node); for (auto o : engine_node->outputs()) { g->registerOutput(o); } - + return; } @@ -69,48 +69,50 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, auto g = mod.get_method(method_name).graph(); // Go through PyTorch Lowering to simplify graph and extract weight parameters auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue()); - + g = graph_and_parameters.first; - + // Go through TRTorch Lowering to reformat graph to be conversion friendly // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT) lowering::LowerGraph(g); - + auto params = graph_and_parameters.second; auto named_params = conversion::get_named_params(g->inputs(), params); LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n"); - + // Is this necessary? lowering::LowerBlock(g->block()); - + return conversion::VerifyConverterSupportForBlock(g->block()); } std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, - conversion::ExtraInfo cfg) { + ExtraInfo cfg) { + auto convert_cfg = std::move(cfg.convert_info); + auto g = mod.get_method(method_name).graph(); // Go through PyTorch Lowering to simplify graph and extract weight parameters auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue()); - + g = graph_and_parameters.first; - + // Go through TRTorch Lowering to reformat graph to be conversion friendly // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT) lowering::LowerGraph(g); - + auto params = graph_and_parameters.second; auto named_params = conversion::get_named_params(g->inputs(), params); LOG_INFO(*g << "(CompileGraph)\n"); - + // Is this necessary? lowering::LowerBlock(g->block()); - auto engine = ConvertBlockToEngine(g->block(), cfg, named_params); + auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params); return std::move(engine); } torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, - conversion::ExtraInfo cfg) { + ExtraInfo cfg) { // TODO: Should be doing a functional transform but need PR #31978 // [jit] More robust mangling // torch::jit::script::Module new_mod = mod.clone(); @@ -128,7 +130,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, return new_mod; } - + } // namespace core } // namespace trtorch diff --git a/core/compiler.h b/core/compiler.h index 17ab1719db..a8b248d3a2 100644 --- a/core/compiler.h +++ b/core/compiler.h @@ -6,12 +6,19 @@ namespace trtorch { namespace core { + +struct ExtraInfo { + ExtraInfo(std::vector input_ranges) + : convert_info(std::move(input_ranges)) {} + conversion::ConversionInfo convert_info; +}; + bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name); std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, - std::string method_name, conversion::ExtraInfo cfg); + std::string method_name, ExtraInfo cfg); -torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, conversion::ExtraInfo cfg); +torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo cfg); } // namespace core -} // namespace trtorch +} // namespace trtorch diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 74bf320e1e..d71af6dbdc 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -179,7 +179,7 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) { } } -void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) { +void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) { LOG_INFO(ctx->logger, "Converting Block"); auto inputs = b->inputs(); @@ -221,7 +221,7 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraI // a serialized TensorRT engine that can be deserialized and run // Probably should consolidate these two functions -std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) { +std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) { ConversionCtx ctx(build_info.engine_settings); ConvertBlockToNetDef(&ctx, b, build_info, static_params); std::string engine = ctx.SerializeEngine(); diff --git a/core/conversion/conversion.h b/core/conversion/conversion.h index f053eced84..c7a50a6319 100644 --- a/core/conversion/conversion.h +++ b/core/conversion/conversion.h @@ -30,10 +30,10 @@ struct InputRange { std::vector max_shape); }; -struct ExtraInfo { +struct ConversionInfo { std::vector input_ranges; BuilderSettings engine_settings; - ExtraInfo(std::vector input_ranges) + ConversionInfo(std::vector input_ranges) : input_ranges(std::move(input_ranges)), engine_settings(BuilderSettings()) {} }; @@ -43,7 +43,7 @@ GraphParams get_named_params(c10::ArrayRef inputs, std::vect // Converts a already lowered block (blocks with no sub blocks) to // a serialized TensorRT engine that can be deserialized and run -std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params); +std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params); bool OpSupported(const torch::jit::Node* n); diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp index 81a4ad6f8e..7348dfe6c7 100644 --- a/core/conversion/conversionctx/ConversionCtx.cpp +++ b/core/conversion/conversionctx/ConversionCtx.cpp @@ -20,7 +20,7 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) { << "\n Max Workspace Size: " << s.workspace_size \ << "\n Device Type: " << s.device \ << "\n Engine Capability: " << s.capability \ - << "\n Calibrator Created: " << s.calibrator ? true : false; + << "\n Calibrator Created: " << (s.calibrator != nullptr); return os; } diff --git a/core/conversion/converters/impl/batch_norm.cpp b/core/conversion/converters/impl/batch_norm.cpp index f1ee48d59b..f4fcdc9e29 100644 --- a/core/conversion/converters/impl/batch_norm.cpp +++ b/core/conversion/converters/impl/batch_norm.cpp @@ -83,7 +83,8 @@ volatile auto batch_norm_registrations = RegisterNodeConversionPatterns() auto gamma = args[1].unwrapToTensor(); if (/*training*/ args[5].unwrapToBool()) { - LOG_WARNING("TensorRT only converts forward pass of graphs, but saw training = True, may see undefined behavior, consider placing module in eval mode"); + LOG_WARNING(R"WARN(TRTorch only converts forward pass of graphs, but saw training = True, may see + unexpected behavior, consider placing module in eval mode before exporting the TorchScript module)WARN"); } // If gamma is None this fails diff --git a/core/conversion/converters/impl/pooling.cpp b/core/conversion/converters/impl/pooling.cpp index 85a4acf5bf..04472ce5fc 100644 --- a/core/conversion/converters/impl/pooling.cpp +++ b/core/conversion/converters/impl/pooling.cpp @@ -79,20 +79,17 @@ auto pooling_registrations = RegisterNodeConversionPatterns() for (size_t i = 0; i < out_shape.size(); i++) { stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_shape[(out_shape.size() - 1) - i]; } - LOG_DEBUG("Stride" << util::toDims(stride)); + LOG_DEBUG("Stride: " << util::toDims(stride)); std::vector window(out_shape.size()); for (size_t i = 0; i < out_shape.size(); i++) { window[window.size() - 1 - i] = in_shape[in_shape.size() - 1 - i] - (out_shape[out_shape.size() - 1 - i] - 1) * stride[stride.size() - 1 - i]; } - LOG_DEBUG("Window" << util::toDims(window)); + LOG_DEBUG("Window: " << util::toDims(window)); auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window)); - if (!new_layer) { - LOG_ERROR("Unable to create average pooling layer from node: " << *n); - return false; - } + TRTORCH_CHECK(new_layer, "Unable to create average pooling layer from node: " << *n); new_layer->setStrideNd(util::toDims(stride)); diff --git a/core/quantization/BUILD b/core/quantization/BUILD deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/core/quantization/TRTEntropyCalibrator.cpp b/core/quantization/TRTEntropyCalibrator.cpp deleted file mode 100644 index ebc9188da3..0000000000 --- a/core/quantization/TRTEntropyCalibrator.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include "core/util/prelude.h" -#include "core/quantization/quantization.h" - -namespace trtorch { -namespace core { -namespace quantization { - -Int8CalibratorImpl::Int8CalibratorImpl(QuantizationSettings&& settings) - : dataset_(std::move(settings.calibration_dataset), - cache_file_path_(settings.calibration_cache_file), - use_cache_(settings.use_cache) { - buffers_.reserve(dataset_.size); - -} - -int Int8CalibratorImpl::GetBatchSize() const { - -} - -bool Int8CalibratorImpl::GetBatch(void* bindings[], const char* names[], int num_bindings) { - if (!is_next_batch) { - return false; - } - - for (size_t i = 0; i < num_bindings; i++) { - auto batch = next_binding_batch(names[i]); - batch = batch.to(at::kCUDA).contiguous(); - bindings[i] = batch.data_ptr(); - } - return true; -} - -const void* Int8CalibratorImpl::ReadCalibrationCache(size_t& length) { - cache_.clear(); - std::ifstream cache_file(cache_file_path_, std::ios::binary); - cache_file >> std::noskipws; - if (use_cache && cache_file.good()) { - std::copy(std::istream_iterator(input), - std::istream_iterator(), - std::back_inserter(cache_)); - } - cache_size_ = cache_.size(); - return cache_size ? cache_.data() : nullptr; -} - -void Int8CalibratorImpl::WriteCalibrationCache(const void* cache, size_t length) { - std::ofstream cache_file(cache_file_path_, std::ios::binary); - cache_file.write(reinterpret_cast(cache_), cache_size_); -} - -nvinfer1::IInt8Calibrator create_int8_calibrator(QuantizationSettings settings) { - auto calibrator_impl = Int8CalibratorImpl(settings); - switch(settings.calibrator_type) { - case CalibratorKind::kMinMax: - return TRTInt8MinMaxCalibrator(std::move(calibrator_impl)); - case CalibratorKind::kEntropy: - default: - return TRTInt8EntropyCalibrator(std::move(calibrator_impl)); - } -} - -} // quantization -} // core -} // trtorch diff --git a/core/quantization/quantization.h b/core/quantization/quantization.h deleted file mode 100644 index 5a6c150923..0000000000 --- a/core/quantization/quantization.h +++ /dev/null @@ -1,69 +0,0 @@ -#pragma once -#include "ATen/tensor.h" -#include "NvInfer.h" - -namespace trtorch { -namespace core { -namespace quantization { - -enum class CalibratorKind { - kEntropy, - kMinMax, -} - -in conveter or whatever -in order given std::vector -> map - -struct QuantizationSettings { - CalibratorKind calibrator_type = CalibratorKind::kEntropy; - const std::string& calibration_cache_file = ""; - bool use_cache = false; - std::unordered_map calibration_dataset; -}; - -class CalibrationBatchStream { - -}; - -class Int8CalibratorImpl { -public: - TRTInt8CalibratorImpl(QuantizationSettings& settings); - int GetBatchSize() const; - bool GetBatch(void* bindings[], const char* names[], int num_bindings); - const void* ReadCalibrationCache(size_t& length); - void WriteCalibrationCache(const void* cache, size_t length); -private: - std::unordered_map dataset_; - const std::string& cache_file_path_; - std::vector cache_; - bool use_cache_; - size_t cache_size_ = 0; -}; - -class TRTInt8EntropyCalibrator : nvinfer1::IInt8EntropyCalibrator2 { -public: - TRTInt8EntropyCalibrator(Int8CalibratorImpl impl) : impl_(impl) {} - int getBatchSize() const override {return impl_.GetBatchSize();} - bool getBatch(void* bindings[], const char* names[], int nbBindings) override {return impl_.GetBatch(bindings, names, nbBindings)}; - const void* readCalibrationCache(size_t& length) override {return impl_.ReadCalibrationCache(size_t& length)}; - void writeCalibrationCache(const void* cache, size_t length) override {impl_.WriteCalibrationCache(const void* cache, size_t length)}; -private: - Int8CalibratorImpl impl_; -}; - -class TRTInt8MinMaxCalibrator : nvinfer1::IInt8MinMaxCalibrator { -public: - TRTInt8EntropyCalibrator(Int8CalibratorImpl impl) : impl_(impl) {} - int getBatchSize() const override {return impl_.GetBatchSize();} - bool getBatch(void* bindings[], const char* names[], int nbBindings) override {return impl_.GetBatch(bindings, names, nbBindings)}; - const void* readCalibrationCache(size_t& length) override {return impl_.ReadCalibrationCache(size_t& length)}; - void writeCalibrationCache(const void* cache, size_t length) override {impl_.WriteCalibrationCache(const void* cache, size_t length)}; -private: - Int8CalibratorImpl impl_; -}; - -nvinfer1::IInt8Calibrator create_int8_calibrator(QuantizationSettings settings); - -} // quantization -} // core -} // trtorch \ No newline at end of file diff --git a/cpp/api/BUILD b/cpp/api/BUILD index ed34d1ecd5..fc0bb75408 100644 --- a/cpp/api/BUILD +++ b/cpp/api/BUILD @@ -5,12 +5,13 @@ cc_library( hdrs = [ "include/trtorch/trtorch.h", "include/trtorch/logging.h", - "include/trtorch/macros.h" + "include/trtorch/macros.h", + "include/trtorch/ptq.h" ], srcs = [ - "src/trtorch.cpp", "src/extra_info.cpp", - "src/logging.cpp" + "src/logging.cpp", + "src/trtorch.cpp", ], deps = [ "//core", @@ -20,8 +21,8 @@ cc_library( linkstatic = True, alwayslink = True ) - - + + filegroup( name = "api_headers", srcs = glob(["include/**/*.h"]), diff --git a/cpp/api/include/trtorch/ptq.h b/cpp/api/include/trtorch/ptq.h new file mode 100644 index 0000000000..71eaeaaf08 --- /dev/null +++ b/cpp/api/include/trtorch/ptq.h @@ -0,0 +1,162 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace nvinfer1 { +class IInt8Calibrator; +class IInt8EntropyCalibrator2; +} + +namespace trtorch { +namespace ptq { + +template +class Int8Calibrator : Algorithm { + using DataLoader = typename DataLoaderUniquePtr::element_type; + using Batch = typename DataLoader::super::BatchType; +public: + Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache) + : dataloader_(dataloader.get()), it_(dataloader_->begin()), cache_file_path_(cache_file_path), use_cache_(use_cache) {} + + int getBatchSize() const override { + // HACK: TRTorch only uses explict batch sizing, INT8 Calibrator does not + // work when reporting the batch size here and having explicity batching. + // So we just report batch size 1 (warnings will still be printed out). + return 1; + //return static_cast(dataloader_->options().batch_size); + } + + bool getBatch(void* bindings[], const char* names[], int nbBindings) override { + // HACK: doesnt seem like the first try in the initializer list works + if (! it_created_) { + it_ = dataloader_->begin(); + it_created_ = true; + } + + if (it_ == dataloader_->end()) { + return false; + } + + auto batch = *it_; + + for (int i = 0; i < nbBindings; i++) { + auto data = batch.data; + data = data.to(at::kCUDA).contiguous(); + bindings[i] = data.data_ptr(); + } + + it_ = ++it_; + return true; + } + + const void* readCalibrationCache(size_t& length) override { + if (use_cache_) { + std::stringstream ss; + ss << "Reading Calibration Cache from " << cache_file_path_; + logging::log(logging::Level::kINFO, ss.str()); + cache_.clear(); + std::ifstream cache_file(cache_file_path_, std::ios::binary); + cache_file >> std::noskipws; + if (cache_file.good()) { + std::copy(std::istream_iterator(cache_file), + std::istream_iterator(), + std::back_inserter(cache_)); + ss << "Cache read"; + logging::log(logging::Level::kDEBUG, ss.str()); + } + cache_size_ = cache_.size(); + return cache_size_ ? cache_.data() : nullptr; + } + return nullptr; + } + + void writeCalibrationCache(const void* cache, size_t length) override { + std::ofstream cache_file(cache_file_path_, std::ios::binary); + cache_file.write(reinterpret_cast(cache), length); + std::stringstream ss; + ss << "Saved Calibration Cache to " << cache_file_path_; + logging::log(logging::Level::kINFO, ss.str()); + } + + operator nvinfer1::IInt8Calibrator* () { + return reinterpret_cast(this); + } + + ~Int8Calibrator() { + delete dataloader_; + } + +private: + DataLoader* dataloader_; + torch::data::Iterator it_; + const std::string& cache_file_path_; + size_t cache_size_ = 0; + bool use_cache_; + std::vector cache_; + bool it_created_ = false; +}; + +template +class Int8CacheCalibrator : Algorithm { +public: + Int8CacheCalibrator(const std::string& cache_file_path) + : cache_file_path_(cache_file_path) {} + + int getBatchSize() const override { + // HACK: TRTorch only uses explict batch sizing, INT8 Calibrator does not + // work when reporting the batch size here and having explicity batching. + // So we just report batch size 1 (warnings will still be printed out). + return 1; + } + + bool getBatch(void* bindings[], const char* names[], int nbBindings) override { + return false; + } + + const void* readCalibrationCache(size_t& length) override { + std::stringstream ss; + ss << "Reading Calibration Cache from " << cache_file_path_; + logging::log(logging::Level::kINFO, ss.str()); + cache_.clear(); + std::ifstream cache_file; + cache_file.open(cache_file_path_, std::ios::in | std::ios::binary); + cache_file.unsetf(std::ios::skipws); + cache_file.seekg(0, std::ios::beg); + cache_.reserve(cache_file.tellg()); + cache_file.seekg(0, std::ios::beg); + if (cache_file.good()) { + std::cout << "Trying to read cache" << std::endl; + std::copy(std::istreambuf_iterator(cache_file), + std::istreambuf_iterator(), + std::back_inserter(cache_)); + ss << "Cache read"; + logging::log(logging::Level::kDEBUG, ss.str()); + } + cache_size_ = cache_.size(); + return cache_size_ ? cache_.data() : nullptr; + } + + void writeCalibrationCache(const void* cache, size_t length) override { + std::ofstream cache_file(cache_file_path_, std::ios::binary); + cache_file.write(reinterpret_cast(cache), length); + std::stringstream ss; + ss << "Saved Calibration Cache to " << cache_file_path_; + logging::log(logging::Level::kINFO, ss.str()); + } + + operator nvinfer1::IInt8Calibrator* () { + return reinterpret_cast(this); + } + +private: + const std::string& cache_file_path_; + size_t cache_size_ = 0; + std::vector cache_; +}; + +} // namespace ptq +} // namespace trtorch \ No newline at end of file diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h index 8275b4b0a2..ed5f729382 100644 --- a/cpp/api/include/trtorch/trtorch.h +++ b/cpp/api/include/trtorch/trtorch.h @@ -12,13 +12,16 @@ #include #include +#include "torch/torch.h" +#include "NvInfer.h" + // Just include the .h? namespace torch { namespace jit { struct Graph; namespace script { struct Module; -} // namespace script +} // namespace script } // namespace jit } // namespace torch @@ -29,10 +32,14 @@ template class ArrayRef; } +namespace nvinfer1 { +class IInt8EntropyCalibrator2; +} + #include "trtorch/macros.h" #include "trtorch/logging.h" +#include "trtorch/ptq.h" namespace trtorch { - /** * Settings data structure for TRTorch compilation * @@ -41,7 +48,7 @@ struct TRTORCH_API ExtraInfo { /** * @brief A struct to hold an input range (used by TensorRT Optimization profile) * - * This struct can either hold a single vector representing an input shape, signifying a + * This struct can either hold a single vector representing an input shape, signifying a * static input shape or a set of three input shapes representing the min, optiminal and max * input shapes allowed for the engine. */ @@ -59,7 +66,7 @@ struct TRTORCH_API ExtraInfo { * Supported Data Types that can be used with TensorRT engines * * This class is compatable with c10::DataTypes (but will check for TRT support) - * so there should not be a reason that you need to use this type explictly. + * so there should not be a reason that you need to use this type explictly. */ class DataType { public: @@ -72,14 +79,14 @@ struct TRTORCH_API ExtraInfo { * ex. trtorch::DataType type = DataType::kFloat; */ enum Value : int8_t { - /// FP32 + /// FP32 kFloat, /// FP16 kHalf, /// INT8 - /*kChar, char or int8? */ + kChar, }; - + DataType() = default; constexpr DataType(Value t) : value(t) {} DataType(c10::ScalarType t); @@ -96,7 +103,7 @@ struct TRTORCH_API ExtraInfo { * * This class is compatable with c10::DeviceTypes (but will check for TRT support) * but the only applicable value is at::kCUDA, which maps to DeviceType::kGPU - * + * * To use the DataType class itself, interface using the enum vs. normal instatination * * ex. trtorch::DeviceType type = DeviceType::kGPU; @@ -130,7 +137,7 @@ struct TRTORCH_API ExtraInfo { }; /** - * Emum for selecting engine capability + * Emum for selecting engine capability */ enum class EngineCapability : int8_t { kDEFAULT, @@ -142,24 +149,24 @@ struct TRTORCH_API ExtraInfo { : input_ranges(std::move(input_ranges)) {} ExtraInfo(std::vector> fixed_sizes); ExtraInfo(std::vector> fixed_sizes); - + // Defaults should reflect TensorRT defaults for BuilderConfig - /** + /** * Sizes for inputs to engine, can either be a single size or a range - * defined by Min, Optimal, Max sizes - * - * Order is should match call order + * defined by Min, Optimal, Max sizes + * + * Order is should match call order */ std::vector input_ranges; /** - * Default operating precision for the engine + * Default operating precision for the engine */ DataType op_precision = DataType::kFloat; - + /** - * Build a refitable engine + * Build a refitable engine */ bool refit = false; @@ -174,7 +181,7 @@ struct TRTORCH_API ExtraInfo { bool strict_type = false; /** - * (Only used when targeting DLA (device)) + * (Only used when targeting DLA (device)) * Lets engine run layers on GPU if they are not supported on DLA */ bool allow_gpu_fallback = true; @@ -201,7 +208,12 @@ struct TRTORCH_API ExtraInfo { /** * Maximum size of workspace given to TensorRT */ - uint64_t workspace_size = 0; + uint64_t workspace_size = 1 << 20; + + /** + * Calibration dataloaders for each input for post training quantizatiom + */ + nvinfer1::IInt8Calibrator* ptq_calibrator; }; /** @@ -211,46 +223,84 @@ TRTORCH_API std::string get_build_info(); /** * Dump the version information for TRTorch including base libtorch and TensorRT versions - * to stdout + * to stdout */ TRTORCH_API void dump_build_info(); /** * @brief Check to see if a module is fully supported by the compiler * - * @param module: torch::jit::script::Module - Existing TorchScript module + * @param module: torch::jit::script::Module - Existing TorchScript module * @param method_name: std::string - Name of method to compile * * Takes a module and a method name and checks if the method graph contains purely - * convertable operators - * + * convertable operators + * * Will print out a list of unsupported operators if the graph is unsupported - */ + */ TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name); /** * @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT * - * @param module: torch::jit::script::Module - Existing TorchScript module - * @param info: trtorch::ExtraInfo - Compilation settings + * @param module: torch::jit::script::Module - Existing TorchScript module + * @param info: trtorch::ExtraInfo - Compilation settings * * Takes a existing TorchScript module and a set of settings to configure the compiler * and will convert methods to JIT Graphs which call equivalent TensorRT engines * - * Converts specifically the forward method of a TorchScript Module - */ + * Converts specifically the forward method of a TorchScript Module + */ TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info); /** * @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT * - * @param module: torch::jit::script::Module - Existing TorchScript module + * @param module: torch::jit::script::Module - Existing TorchScript module * @param method_name: std::string - Name of method to compile - * @param info: trtorch::ExtraInfo - Compilation settings + * @param info: trtorch::ExtraInfo - Compilation settings * * Takes a existing TorchScript module and a set of settings to configure the compiler * and will convert selected method to a serialized TensorRT engine which can be run with * TensorRT */ TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info); + +namespace ptq { +/** + * @brief A factory to build a post training quantization calibrator from a torch dataloader + * + * Creates a calibrator to use for post training quantization + * If there are multiple inputs, the dataset should produce a example which is a vector (or similar container) of tensors vs a single tensor + * + * By default the returned calibrator uses TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed forward networks + * You can override the algorithm selection (such as to use the MinMax Calibrator recomended for NLP tasks) by calling make_int8_calibrator with + * the calibrator class as a template parameter. + * + * e.g. trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, use_cache); + */ +template +TRTORCH_API inline Int8Calibrator make_int8_calibrator(DataLoader dataloader, const std::string& cache_file_path, bool use_cache) { + return Int8Calibrator(std::move(dataloader), cache_file_path, use_cache); +} + +/** + * @brief A factory to build a post training quantization calibrator from a torch dataloader that only uses the calibration cache + * + * Creates a calibrator to use for post training quantization which reads from a previously created calibration cache, therefore + * you can have a calibration cache generating program that requires a dataloader and a dataset, then save the cache to use later + * in a different program that needs to calibrate from scratch and not have the dataset dependency. However, the network should also + * be recalibrated if its structure changes, or the input data set changes, and it is the responsibility of the application to ensure this. + * + * By default the returned calibrator uses TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed forward networks + * You can override the algorithm selection (such as to use the MinMax Calibrator recomended for NLP tasks) by calling make_int8_calibrator with + * the calibrator class as a template parameter. + * + * e.g. trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file); + */ +template +TRTORCH_API inline Int8CacheCalibrator make_int8_cache_calibrator(const std::string& cache_file_path) { + return Int8CacheCalibrator(cache_file_path); +} +} // namespace ptq } // namespace trtorch diff --git a/cpp/api/src/extra_info.cpp b/cpp/api/src/extra_info.cpp index 625106a677..23a47e7801 100644 --- a/cpp/api/src/extra_info.cpp +++ b/cpp/api/src/extra_info.cpp @@ -7,7 +7,7 @@ namespace trtorch { ExtraInfo::DataType::DataType(c10::ScalarType t) { - assert(t == at::kHalf || t == at::kFloat /*|| t == at::kChar*/); + TRTORCH_CHECK(t == at::kHalf || t == at::kFloat || t == at::kChar, "Data type is unsupported"); switch (t) { case at::kHalf: value = DataType::kHalf; @@ -16,16 +16,16 @@ ExtraInfo::DataType::DataType(c10::ScalarType t) { default: value = DataType::kFloat; break; - // case at::kChar: - // value = DataType::kChar; + case at::kChar: + value = DataType::kChar; } } ExtraInfo::DeviceType::DeviceType(c10::DeviceType t) { - assert(t == at::kCUDA); + TRTORCH_CHECK(t == at::kCUDA, "Device type when specified using torch device enum must be torch::kCUDA"); value = DeviceType::kGPU; } - + ExtraInfo::InputRange::InputRange(std::vector opt) { this->opt = opt; this->min = opt; @@ -74,51 +74,57 @@ std::vector to_vec_internal_input_ranges(std::vect return internal; } -core::conversion::ExtraInfo to_internal_extra_info(ExtraInfo external) { - core::conversion::ExtraInfo internal(to_vec_internal_input_ranges(external.input_ranges)); +core::ExtraInfo to_internal_extra_info(ExtraInfo external) { + core::ExtraInfo internal(to_vec_internal_input_ranges(external.input_ranges)); switch(external.op_precision) { - // case ExtraInfo::DataType::kChar: - // internal.engine_settings.op_precision = nvinfer1::DataType::kINT8; - // break; + case ExtraInfo::DataType::kChar: + internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kINT8; + break; case ExtraInfo::DataType::kHalf: - internal.engine_settings.op_precision = nvinfer1::DataType::kHALF; + internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kHALF; break; case ExtraInfo::DataType::kFloat: default: - internal.engine_settings.op_precision = nvinfer1::DataType::kFLOAT; + internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT; } - - internal.engine_settings.refit = external.refit; - internal.engine_settings.debug = external.debug; - internal.engine_settings.strict_type = external.strict_type; - internal.engine_settings.allow_gpu_fallback = external.allow_gpu_fallback; + + internal.convert_info.engine_settings.refit = external.refit; + internal.convert_info.engine_settings.debug = external.debug; + internal.convert_info.engine_settings.strict_type = external.strict_type; + internal.convert_info.engine_settings.allow_gpu_fallback = external.allow_gpu_fallback; switch(external.device) { case ExtraInfo::DeviceType::kDLA: - internal.engine_settings.device = nvinfer1::DeviceType::kDLA; + internal.convert_info.engine_settings.device = nvinfer1::DeviceType::kDLA; break; case ExtraInfo::DeviceType::kGPU: default: - internal.engine_settings.device = nvinfer1::DeviceType::kGPU; + internal.convert_info.engine_settings.device = nvinfer1::DeviceType::kGPU; } switch(external.capability) { case ExtraInfo::EngineCapability::kSAFE_GPU: - internal.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_GPU; + internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_GPU; break; case ExtraInfo::EngineCapability::kSAFE_DLA: - internal.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_DLA; + internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFE_DLA; break; case ExtraInfo::EngineCapability::kDEFAULT: default: - internal.engine_settings.capability = nvinfer1::EngineCapability::kDEFAULT; - + internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kDEFAULT; + + } + + internal.convert_info.engine_settings.num_min_timing_iters = external.num_min_timing_iters; + internal.convert_info.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters; + internal.convert_info.engine_settings.workspace_size = external.workspace_size; + + if (internal.convert_info.engine_settings.op_precision == nvinfer1::DataType::kINT8) { + internal.convert_info.engine_settings.calibrator = external.ptq_calibrator; + } else { + internal.convert_info.engine_settings.calibrator = nullptr; } - - internal.engine_settings.num_min_timing_iters = external.num_min_timing_iters; - internal.engine_settings.num_avg_timing_iters = external.num_avg_timing_iters; - internal.engine_settings.workspace_size = external.workspace_size; return internal; } diff --git a/cpp/api/src/trtorch.cpp b/cpp/api/src/trtorch.cpp index 562f4faa9f..bb8e5a7845 100644 --- a/cpp/api/src/trtorch.cpp +++ b/cpp/api/src/trtorch.cpp @@ -8,7 +8,7 @@ namespace trtorch { // Defined in extra_info.cpp -core::conversion::ExtraInfo to_internal_extra_info(ExtraInfo external); +core::ExtraInfo to_internal_extra_info(ExtraInfo external); bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name) { diff --git a/cpp/ptq/BUILD b/cpp/ptq/BUILD index f17db99cc0..aa2762494e 100644 --- a/cpp/ptq/BUILD +++ b/cpp/ptq/BUILD @@ -3,10 +3,19 @@ package(default_visibility = ["//visibility:public"]) cc_binary( name = "ptq", srcs = [ - "main.cpp" + "main.cpp", + "timer.h" ], deps = [ + "//cpp/ptq/datasets:cifar10", "@libtorch//:libtorch", - "//cpp/api:trtorch" + "//cpp/api:trtorch", + "@tensorrt//:nvinfer" ], + copts = [ + "-pthread" + ], + linkopts = [ + "-lpthread", + ] ) diff --git a/cpp/ptq/datasets/BUILD b/cpp/ptq/datasets/BUILD new file mode 100644 index 0000000000..32d9fc5de7 --- /dev/null +++ b/cpp/ptq/datasets/BUILD @@ -0,0 +1,14 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "cifar10", + hdrs = [ + "cifar10.h" + ], + srcs = [ + "cifar10.cpp" + ], + deps = [ + "@libtorch//:libtorch" + ] +) \ No newline at end of file diff --git a/cpp/ptq/datasets/cifar10.cpp b/cpp/ptq/datasets/cifar10.cpp new file mode 100644 index 0000000000..d7768b91eb --- /dev/null +++ b/cpp/ptq/datasets/cifar10.cpp @@ -0,0 +1,119 @@ +#include "cpp/ptq/datasets/cifar10.h" + +#include "torch/data/example.h" +#include "torch/types.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace datasets { +namespace { +constexpr const char* kTrainFilenamePrefix = "data_batch_"; +constexpr const uint32_t kNumTrainFiles = 5; +constexpr const char* kTestFilename = "test_batch.bin"; +constexpr const size_t kLabelSize = 1; // B +constexpr const size_t kImageSize = 3072; // B +constexpr const size_t kImageDim = 32; +constexpr const size_t kImageChannels = 3; +constexpr const size_t kBatchSize = 10000; + +std::pair read_batch(const std::string& path) { + std::ifstream batch; + batch.open(path, std::ios::in|std::ios::binary|std::ios::ate); + + auto file_size = batch.tellg(); + std::unique_ptr buf(new char[file_size]); + + batch.seekg(0, std::ios::beg); + batch.read(buf.get(), file_size); + batch.close(); + + std::vector labels; + std::vector images; + labels.reserve(kBatchSize); + images.reserve(kBatchSize); + + for (size_t i = 0; i < kBatchSize; i++) { + uint8_t label = buf[i * (kImageSize + kLabelSize)]; + std::vector image; + image.reserve(kImageSize); + std::copy(&buf[i * (kImageSize + kLabelSize) + 1], &buf[i * (kImageSize + kLabelSize) + kImageSize], std::back_inserter(image)); + labels.push_back(label); + auto image_tensor = torch::from_blob(image.data(), + {kImageChannels, kImageDim, kImageDim}, + torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32); + images.push_back(image_tensor); + } + + auto labels_tensor = torch::from_blob(labels.data(), + {kBatchSize}, + torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32); + assert(labels_tensor.size(0) == kBatchSize); + + auto images_tensor = torch::stack(images); + assert(images_tensor.size(0) == kBatchSize); + + return std::make_pair(images_tensor, labels_tensor); +} + +std::pair read_train_data(const std::string& root) { + torch::Tensor images, targets; + for(uint32_t i = 1; i <= 5; i++) { + std::stringstream ss; + ss << root << '/' << kTrainFilenamePrefix << i << ".bin"; + auto batch = read_batch(ss.str()); + images = torch::stack({images, batch.first}); + targets = torch::stack({targets, batch.second}); + } + return std::make_pair(images, targets); +} + +std::pair read_test_data(const std::string& root) { + std::stringstream ss; + ss << root << '/' << kTestFilename; + return read_batch(ss.str()); +} +} + +CIFAR10::CIFAR10(const std::string& root, Mode mode) + : mode_(mode) { + + std::pair data; + if (mode_ == Mode::kTrain) { + data = read_train_data(root); + } else { + data = read_test_data(root); + } + + images_ = std::move(data.first); + targets_ = std::move(data.second); +} + +torch::data::Example<> CIFAR10::get(size_t index) { + return {images_[index], targets_[index]}; +} + +c10::optional CIFAR10::size() const { + return images_.size(0); +} + +bool CIFAR10::is_train() const noexcept { + return mode_ == Mode::kTrain; +} + +const torch::Tensor& CIFAR10::images() const { + return images_; +} + +const torch::Tensor& CIFAR10::targets() const { + return targets_; +} + +} // namespace datasets + diff --git a/cpp/ptq/datasets/cifar10.h b/cpp/ptq/datasets/cifar10.h new file mode 100644 index 0000000000..ba6ecb8d37 --- /dev/null +++ b/cpp/ptq/datasets/cifar10.h @@ -0,0 +1,41 @@ +#pragma once + +#include "torch/data/datasets/base.h" +#include "torch/data/example.h" +#include "torch/types.h" + +#include +#include + +namespace datasets { +// The CIFAR10 Dataset +class CIFAR10 : public torch::data::datasets::Dataset { +public: + // The mode in which the dataset is loaded + enum class Mode { kTrain, kTest }; + + // Loads CIFAR10 from un-tarred file + // Dataset can be found https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz + // Root path should be the directory that contains the content of tarball + explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain); + + // Returns the pair at index in the dataset + torch::data::Example<> get(size_t index) override; + + // The size of the dataset + c10::optional size() const override; + + // The mode the dataset is in + bool is_train() const noexcept; + + // Returns all images stacked into a single tensor + const torch::Tensor& images() const; + + // Returns all targets stacked into a single tensor + const torch::Tensor& targets() const; + +private: + Mode mode_; + torch::Tensor images_, targets_; +}; +} // namespace datasets diff --git a/cpp/ptq/main.cpp b/cpp/ptq/main.cpp index f7fe857c99..6499612eb8 100644 --- a/cpp/ptq/main.cpp +++ b/cpp/ptq/main.cpp @@ -1,36 +1,112 @@ #include "torch/script.h" -#include "torch/csrc/api/include/torch/data/datasets/mnist.h" +#include "torch/torch.h" #include "trtorch/trtorch.h" +#include "NvInfer.h" + +#include "datasets/cifar10.h" +#include "timer.h" + #include #include #include +#include int main(int argc, const char* argv[]) { + trtorch::logging::set_reportable_log_level(trtorch::logging::kINFO); if (argc < 3) { - std::cerr << "usage: ptq \n"; + std::cerr << "usage: ptq \n"; return -1; } torch::jit::script::Module mod; try { - // Deserialize the ScriptModule from a file using torch::jit::load(). - mod = torch::jit::load(argv[1]); + // Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(argv[1]); } catch (const c10::Error& e) { - std::cerr << "error loading the model\n"; - return -1; + std::cerr << "error loading the model\n"; + return -1; } + // Create the calibration dataset const std::string data_dir = std::string(argv[2]); - auto calibration_dataset = torch::data::datasets::MNIST(data_dir, torch::data::datasets::MNIST::Mode::kTest) - .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) + auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest) + .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, + {0.2023, 0.1994, 0.2010})) .map(torch::data::transforms::Stack<>()); auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset), torch::data::DataLoaderOptions() .batch_size(32) - .workers(1)) + .workers(2)); + + std::string calibration_cache_file = "/tmp/vgg16_TRT_ptq_calibration.cache"; + + //auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true); + auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file); + + + std::vector> input_shape = {{32, 3, 32, 32}}; + // Configure settings for compilation + auto extra_info = trtorch::ExtraInfo({input_shape}); + // Set operating precision to INT8 + extra_info.op_precision = torch::kChar; + // Use the TensorRT Entropy Calibrator + extra_info.ptq_calibrator = calibrator; + // Increase the default workspace size; + extra_info.workspace_size = 1 << 30; + + mod.eval(); + + // Dataloader moved into calibrator so need another for inference + auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest) + .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, + {0.2023, 0.1994, 0.2010})) + .map(torch::data::transforms::Stack<>()); + auto eval_dataloader = torch::data::make_data_loader(std::move(eval_dataset), torch::data::DataLoaderOptions() + .batch_size(32) + .workers(2)); - for (auto batch : batched_calibration_dataset) { - std::cout << batch.data().sizes() << std::endl; + // Check the FP32 accuracy in JIT + float correct = 0.0, total = 0.0; + for (auto batch : *eval_dataloader) { + auto images = batch.data.to(torch::kCUDA); + auto targets = batch.target.to(torch::kCUDA); + + auto outputs = mod.forward({images}); + auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false)); + + total += targets.sizes()[0]; + correct += torch::sum(torch::eq(predictions, targets)).item().toFloat(); + } + std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl; + + // Compile Graph + auto trt_mod = trtorch::CompileGraph(mod, extra_info); + + // Check the INT8 accuracy in TRT + correct = 0.0; + total = 0.0; + for (auto batch : *eval_dataloader) { + auto images = batch.data.to(torch::kCUDA); + auto targets = batch.target.to(torch::kCUDA); + + auto outputs = trt_mod.forward({images}); + auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false)); + + total += targets.sizes()[0]; + correct += torch::sum(torch::eq(predictions, targets)).item().toFloat(); + std::cout << total << " " << correct << std::endl; } + std::cout << total << " " << correct << std::endl; + std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl; + + // Time execution in INT8 + auto execution_timer = timers::PreciseCPUTimer(); + auto images = (*(*eval_dataloader).begin()).data.to(torch::kCUDA); + + execution_timer.start(); + trt_mod.forward({images}); + execution_timer.stop(); + + std::cout << "Latency of quantized model (Batch Size 32): " << execution_timer.milliseconds() << "ms" << std::endl; } diff --git a/cpp/ptq/timer.h b/cpp/ptq/timer.h new file mode 100644 index 0000000000..cef81e8629 --- /dev/null +++ b/cpp/ptq/timer.h @@ -0,0 +1,39 @@ +#pragma once +#include + +namespace timers +{ +class TimerBase +{ +public: + virtual void start() {} + virtual void stop() {} + float microseconds() const noexcept { return mMs * 1000.f; } + float milliseconds() const noexcept { return mMs; } + float seconds() const noexcept { return mMs / 1000.f; } + void reset() noexcept { mMs = 0.f; } + +protected: + float mMs{0.0f}; +}; + + +template +class CPUTimer : public TimerBase +{ +public: + using clock_type = Clock; + + void start() { mStart = Clock::now(); } + void stop() + { + mStop = Clock::now(); + mMs += std::chrono::duration{mStop - mStart}.count(); + } + +private: + std::chrono::time_point mStart, mStop; +}; // class CPUTimer + +using PreciseCPUTimer = CPUTimer; +} // namespace timers diff --git a/cpp/ptq/training/vgg16/export_ckpt.py b/cpp/ptq/training/vgg16/export_ckpt.py new file mode 100644 index 0000000000..de18e0f632 --- /dev/null +++ b/cpp/ptq/training/vgg16/export_ckpt.py @@ -0,0 +1,73 @@ +import argparse +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data as data +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +from vgg16 import vgg16 + +def test(model, dataloader, crit): + global writer + global classes + total = 0 + correct = 0 + loss = 0.0 + class_probs = [] + class_preds = [] + model.eval() + with torch.no_grad(): + for data, labels in dataloader: + data, labels = data.cuda(), labels.cuda(async=True) + out = model(data) + loss += crit(out, labels) + preds = torch.max(out, 1)[1] + class_probs.append([F.softmax(i, dim=0) for i in out]) + class_preds.append(preds) + total += labels.size(0) + correct += (preds == labels).sum().item() + + test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) + test_preds = torch.cat(class_preds) + return loss / total, correct / total + +PARSER = argparse.ArgumentParser(description="Export trained VGG") +PARSER.add_argument('ckpt', type=str, help="Path to saved checkpoint") + +args = PARSER.parse_args() +model = vgg16(num_classes=10, init_weights=False) +model = model.cuda() + +ckpt = torch.load(args.ckpt) +weights = ckpt["model_state_dict"] + +if torch.cuda.device_count() > 1: + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in weights.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + weights = new_state_dict + +model.load_state_dict(weights) + +jit_model = torch.jit.trace(model, torch.rand([32, 3, 32, 32]).to("cuda")) + +testing_dataset = datasets.CIFAR10(root='./data', train=False, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010))])) + +testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=32, + shuffle=False, num_workers=2) + +crit = torch.nn.CrossEntropyLoss() +test_loss, test_acc = test(model, testing_dataloader, crit) +print("[PTH] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc)) +print("[JIT] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc)) + +torch.jit.save(jit_model, "trained_vgg16.jit.pt") diff --git a/cpp/ptq/training/vgg16/main.py b/cpp/ptq/training/vgg16/main.py index 13c97fb3ca..964441aac8 100644 --- a/cpp/ptq/training/vgg16/main.py +++ b/cpp/ptq/training/vgg16/main.py @@ -16,7 +16,7 @@ from vgg16 import vgg16 PARSER = argparse.ArgumentParser(description="VGG16 example to use with TRTorch PTQ") -PARSER.add_argument('--epochs', default=300, type=int, help="Number of total epochs to train") +PARSER.add_argument('--epochs', default=100, type=int, help="Number of total epochs to train") PARSER.add_argument('--batch-size', default=128, type=int, help="Batch size to use when training") PARSER.add_argument('--lr', default=0.1, type=float, help="Initial learning rate") PARSER.add_argument('--drop-ratio', default=0., type=float, help="Dropout ratio") @@ -89,6 +89,9 @@ def main(): crit = nn.CrossEntropyLoss() opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + if torch.cuda.device_count() > 1: + model = nn.DataParallel(model) + if args.start_from != 0: ckpt_file = args.ckpt_dir + '/ckpt_epoch' + str(args.start_from) + '.pth' print('Loading from checkpoint {}'.format(ckpt_file)) @@ -98,9 +101,6 @@ def main(): opt.load_state_dict(ckpt["opt_state_dict"]) state = ckpt["state"] - if torch.cuda.device_count() > 1: - model = nn.DataParallel(model) - for epoch in range(args.start_from, args.epochs): adjust_lr(opt, epoch) writer.add_scalar('Learning Rate', state["lr"], epoch) @@ -170,6 +170,7 @@ def test(model, dataloader, crit, epoch): test_preds = torch.cat(class_preds) for i in range(len(classes)): add_pr_curve_tensorboard(i, test_probs, test_preds, epoch) + #print(loss, total, correct, total) return loss / total, correct / total @@ -181,7 +182,7 @@ def save_checkpoint(state, ckpt_dir='checkpoint'): def adjust_lr(optimizer, epoch): global state - new_lr = state["lr"] * (0.5 ** (epoch // 50)) if state["lr"] > 1e-7 else state["lr"] + new_lr = state["lr"] * (0.5 ** (epoch // 40)) if state["lr"] > 1e-7 else state["lr"] if new_lr != state["lr"]: state["lr"] = new_lr print("Updating learning rate: {}".format(state["lr"])) diff --git a/cpp/ptq/training/vgg16/test.py b/cpp/ptq/training/vgg16/test.py new file mode 100644 index 0000000000..be189e2c04 --- /dev/null +++ b/cpp/ptq/training/vgg16/test.py @@ -0,0 +1,20 @@ +import argparse +import os +import random +from datetime import datetime + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data as data +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +from torch.utils.tensorboard import SummaryWriter + +from vgg16 import vgg16 + +model = vgg16(num_classes=10, init_weights=False) + +model.forward(torch.rand([1,3,224,224])) \ No newline at end of file diff --git a/cpp/ptq/training/vgg16/vgg16.py b/cpp/ptq/training/vgg16/vgg16.py index 0210063be7..7a0b496a78 100644 --- a/cpp/ptq/training/vgg16/vgg16.py +++ b/cpp/ptq/training/vgg16/vgg16.py @@ -1,3 +1,8 @@ +''' +# Reference +- [Very Deep Convolutional Networks for Large-Scale Image Recognition]( + https://arxiv.org/abs/1409.1556) (ICLR 2015) +''' import torch import torch.nn as nn import torch.nn.functional as F @@ -21,9 +26,9 @@ def __init__(self, layer_spec, num_classes=1000, init_weights=False): in_channels = l self.features = nn.Sequential(*layers) - self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.classifier = nn.Sequential( - nn.Linear(512 * 7 * 7, 4096), + nn.Linear(512 * 1 * 1, 4096), nn.ReLU(), nn.Dropout(), nn.Linear(4096, 4096), @@ -55,5 +60,5 @@ def forward(self, x): return x def vgg16(num_classes=1000, init_weights=False): - vgg16_cfg = [64, 64, 'pool', 128, 128, 'pool', 256, 256, 256, 256, 'pool', 512, 512, 512, 512, 'pool', 512, 512, 512, 512, 'pool'] + vgg16_cfg = [64, 64, 'pool', 128, 128, 'pool', 256, 256, 256, 'pool', 512, 512, 512, 'pool', 512, 512, 512, 'pool'] return VGG(vgg16_cfg, num_classes, init_weights) \ No newline at end of file diff --git a/third_party/libtorch/BUILD b/third_party/libtorch/BUILD index ae848c5536..0c5bff0d2f 100644 --- a/third_party/libtorch/BUILD +++ b/third_party/libtorch/BUILD @@ -9,21 +9,24 @@ cc_library( cc_library( name = 'torch', - hdrs = glob([ - 'include/torch/**/*.h', - ], - exclude = ['include/torch/csrc/api/include/**/*.h'] + hdrs = glob( + [ + 'include/torch/**/*.h', + ], exclude = [ + 'include/torch/csrc/api/include/**/*.h' + ] ) + glob([ 'include/torch/csrc/api/include/**/*.h' ]), srcs = ['lib/libtorch.so'], - strip_include_prefix = "include", deps = [ ":ATen", ":torch_deps", ":c10_cuda", - #"@cuda//:cudart", - #"@cuda//:nvToolsExt" + ], + includes = [ + "include", + "include/torch/csrc/api/include/" ] )