From 521a0cbc227af4f1fe7439c75103bb2889451484 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Sun, 25 Jul 2021 18:10:27 -0700
Subject: [PATCH] fix: Final working version of QAT in TRTorch

Signed-off-by: Dheeraj Peri
---
 core/compiler.cpp                             |  6 +--
 .../conversionctx/ConversionCtx.cpp           |  2 +-
 core/conversion/conversionctx/ConversionCtx.h |  2 +
 .../converters/impl/matrix_multiply.cpp       | 46 +------------------
 core/conversion/converters/impl/shuffle.cpp   | 19 +++++---
 core/conversion/evaluators/aten.cpp           | 30 ------------
 core/lowering/lowering.cpp                    | 29 ++++++++----
 core/lowering/lowering.h                      |  5 +-
 cpp/api/include/trtorch/trtorch.h             | 12 ++---
 cpp/api/src/compile_spec.cpp                  |  8 +++-
 py/trtorch/csrc/register_tensorrt_classes.cpp |  2 +-
 py/trtorch/csrc/tensorrt_backend.cpp          |  2 +-
 py/trtorch/csrc/tensorrt_classes.cpp          | 11 ++++-
 .../conversion/converters/test_shuffle.cpp    | 25 +++++++++-
 tests/modules/hub.py                          |  8 ----
 tests/util/run_graph_engine.cpp               |  2 +-
 16 files changed, 92 insertions(+), 117 deletions(-)

diff --git a/core/compiler.cpp b/core/compiler.cpp
index 1f7ab3aa47..ae07920ca3 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -119,7 +119,7 @@ void AddEngineToGraph(
 
 bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) {
   // Go through Lowering to simplify graph and extract weight parameters
-  auto graph_and_parameters = lowering::Lower(mod, method_name);
+  auto graph_and_parameters = lowering::Lower(mod, method_name, false);
 
   auto g = graph_and_parameters.first;
   LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
@@ -129,7 +129,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
   // Go through Lowering to simplify graph and extract weight parameters
-  auto graph_and_parameters = lowering::Lower(mod, method_name);
+  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.convert_info.engine_settings.unfreeze_module);
 
   auto convert_cfg = std::move(cfg.convert_info);
   auto g = graph_and_parameters.first;
@@ -187,7 +187,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
     // Compile only forward methods. forward method contains the entire graph.
     if (method.name().compare("forward") == 0) {
       auto new_g = std::make_shared<torch::jit::Graph>();
-      auto graph_and_parameters = lowering::Lower(mod, method.name());
+      auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.convert_info.engine_settings.unfreeze_module);
 
       auto g = graph_and_parameters.first;
       auto params = graph_and_parameters.second;
diff --git a/core/conversion/conversionctx/ConversionCtx.cpp b/core/conversion/conversionctx/ConversionCtx.cpp
index 21a50d9662..bb88471186 100644
--- a/core/conversion/conversionctx/ConversionCtx.cpp
+++ b/core/conversion/conversionctx/ConversionCtx.cpp
@@ -72,7 +72,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
       if (!settings.calibrator) {
         LOG_WARNING(
            "Int8 precision has been enabled but no calibrator provided. This assumes the network has Q/DQ nodes obtained from Quantization aware training. For more details, refer to https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks");
-      } else{
+      } else {
        cfg->setInt8Calibrator(settings.calibrator);
       }
       break;
diff --git a/core/conversion/conversionctx/ConversionCtx.h b/core/conversion/conversionctx/ConversionCtx.h
index c7f9609776..c649ad098c 100644
--- a/core/conversion/conversionctx/ConversionCtx.h
+++ b/core/conversion/conversionctx/ConversionCtx.h
@@ -27,6 +27,8 @@ struct BuilderSettings {
   bool sparse_weights = false;
   std::set<nvinfer1::DataType> enabled_precisions = {nvinfer1::DataType::kFLOAT};
   bool disable_tf32 = false;
+  // Internal flag to ensure torch.jit.Module does not get frozen in lowering.cpp. This is required for QAT models.
+  bool unfreeze_module = false;
   bool refit = false;
   bool debug = false;
   bool strict_types = false;
diff --git a/core/conversion/converters/impl/matrix_multiply.cpp b/core/conversion/converters/impl/matrix_multiply.cpp
index a84a853ccf..5d169fbedc 100644
--- a/core/conversion/converters/impl/matrix_multiply.cpp
+++ b/core/conversion/converters/impl/matrix_multiply.cpp
@@ -26,6 +26,7 @@ auto mm_registrations TRTORCH_UNUSED =
 
               auto mm_layer = ctx->net->addMatrixMultiply(
                   *self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
+              TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
               mm_layer->setName(util::node_info(n).c_str());
               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
 
@@ -73,51 +74,6 @@ auto mm_registrations TRTORCH_UNUSED =
 
               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
               return true;
-           }})
-       .pattern(
-           {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)",
-            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-              auto self = args[0].ITensorOrFreeze(ctx);
-              auto mat1 = args[1].ITensorOrFreeze(ctx);
-              auto mat2 = args[2].ITensorOrFreeze(ctx);
-              auto beta = args[3].unwrapToScalar().to<float>();
-              auto betaTensor = tensor_to_const(ctx, torch::tensor({beta}));
-              auto alpha = args[4].unwrapToScalar().to<float>();
-              auto alphaTensor = tensor_to_const(ctx, torch::tensor({alpha}));
-
-              // Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if
-              // necessary.
-              if (mat1->getDimensions().nbDims < mat2->getDimensions().nbDims) {
-                mat1 = addPadding(ctx, n, mat1, mat2->getDimensions().nbDims, false, false);
-              } else {
-                mat2 = addPadding(ctx, n, mat2, mat1->getDimensions().nbDims, false, false);
-              }
-
-              auto mm_layer = ctx->net->addMatrixMultiply(
-                  *mat1, nvinfer1::MatrixOperation::kNONE, *mat2, nvinfer1::MatrixOperation::kNONE);
-              TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication layer in node: " << *n);
-              auto mm_scale_layer = add_elementwise(
-                  ctx,
-                  nvinfer1::ElementWiseOperation::kPROD,
-                  mm_layer->getOutput(0),
-                  alphaTensor,
-                  util::node_info(n) + "_alphaScale");
-              TRTORCH_CHECK(mm_scale_layer, "Unable to create alpha scaling layer in node: " << *n);
-              auto beta_scale_layer = add_elementwise(
-                  ctx, nvinfer1::ElementWiseOperation::kPROD, self, betaTensor, util::node_info(n) + "_betaScale");
-              TRTORCH_CHECK(beta_scale_layer, "Unable to create beta scaling layer in node: " << *n);
-              auto add_mm_layer = add_elementwise(
-                  ctx,
-                  nvinfer1::ElementWiseOperation::kSUM,
-                  beta_scale_layer->getOutput(0),
-                  mm_scale_layer->getOutput(0),
-                  util::node_info(n));
-              TRTORCH_CHECK(add_mm_layer, "Unable to create addmm layer in node: " << *n);
-
-              auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
-
-              LOG_DEBUG("[AddMM layer] Output tensor shape: " << out_tensor->getDimensions());
-              return true;
            }});
 } // namespace
 } // namespace impl
diff --git a/core/conversion/converters/impl/shuffle.cpp b/core/conversion/converters/impl/shuffle.cpp
index a78db8f3c8..f27b2b7797 100644
--- a/core/conversion/converters/impl/shuffle.cpp
+++ b/core/conversion/converters/impl/shuffle.cpp
@@ -131,15 +131,22 @@ static auto shuffle_registrations TRTORCH_UNUSED =
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto in = args[0].ITensorOrFreeze(ctx);
               auto input_dims = in->getDimensions();
-              nvinfer1::Dims transposed_input_dims;
-              transposed_input_dims.nbDims = input_dims.nbDims;
-              for (int i = input_dims.nbDims - 1; i >= 0; i--) {
-                transposed_input_dims.d[i] = input_dims.d[input_dims.nbDims - 1 - i];
+              // For input tensors < 2D, return them as is
+              // For a 2D input tensor, return transpose(input, 0, 1) which is a general 2d matrix transpose.
+              if (input_dims.nbDims < 2) {
+                auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+                LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+                return true;
               }
+
               auto shuffle_layer = ctx->net->addShuffle(*in);
               TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
-              shuffle_layer->setReshapeDimensions(transposed_input_dims);
-              shuffle_layer->setZeroIsPlaceholder(true);
+              nvinfer1::Permutation firstPerm;
+              firstPerm.order[0] = 1;
+              firstPerm.order[1] = 0;
+
+              shuffle_layer->setFirstTranspose(firstPerm);
+              shuffle_layer->setZeroIsPlaceholder(false);
               shuffle_layer->setName(util::node_info(n).c_str());
               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
 
diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp
index df2cf84944..a478b3ff0b 100644
--- a/core/conversion/evaluators/aten.cpp
+++ b/core/conversion/evaluators/aten.cpp
@@ -427,36 +427,6 @@ auto aten_registrations TRTORCH_UNUSED =
                         EvalOptions().validSchemas({
                             "aten::numel(Tensor self) -> int",
                         })})
-        // .evaluator({c10::Symbol::fromQualString("aten::t"),
-        //             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-        //               auto tensor_var = args.at(n->input(0));
-        //               if (tensor_var.isIValue() && tensor_var.IValue()->isTensor()) {
-        //                 auto tensor = tensor_var.unwrapToTensor();
-        //                 return tensor.t();
-        //               } else if (tensor_var.isITensor()) {
-        //                 auto input_tensor = tensor_var.ITensor();
-        //                 auto input_dims = input_tensor->getDimensions();
-        //                 LOG_DEBUG("[aten::t] INPUT TENSOR DIMS: " << input_dims);
-        //                 // nvinfer1::Dims transposed_input_dims;
-        //                 // for (int i = input_dims.nbDims - 1; i >= 0; i--) {
-        //                 //   transposed_input_dims.d[i] = input_dims.d[input_dims.nbDims - 1 - i];
-        //                 // }
-        //                 // auto shuffle_layer = ctx->net->addShuffle(*input_tensor);
-        //                 // shuffle_layer->setReshapeDimensions(transposed_input_dims);
-        //                 // shuffle_layer->setZeroIsPlaceholder(true);
-        //                 // auto output_tensor = shuffle_layer->getOutput(0);
-        //                 auto tensor_holder = TensorContainer();
-        //                 tensor_holder.hold_tensor(input_tensor);
-        //                 auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
-        //                 return ival;
-        //               } else {
-        //                 TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator: ITensor");
-        //                 return {};
-        //               }
-        //             },
-        //             EvalOptions().validSchemas({
-        //                 "aten::t(Tensor self) -> Tensor",
-        //             })})
         .evaluator({c10::Symbol::fromQualString("aten::dim"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto tensor_var = args.at(n->input(0));
diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index cad83c0585..fe2c727810 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -24,7 +24,7 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }
 
-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse) {
   passes::UnpackHardSwish(g);
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
@@ -42,9 +42,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::Conv3DToConvolution(g);
   passes::FuseAddMMBranches(g);
   passes::RemoveBNDimCheck(g);
-  LOG_INFO("====PRE CSE =====" << *g);
-  // torch::jit::EliminateCommonSubexpression(g);
-  LOG_INFO("====POST CSE =====" << *g);
+  if (!disable_cse) {
+    torch::jit::EliminateCommonSubexpression(g);
+  }
   // torch::jit::UnrollLoops(g);
   passes::UnpackAddMM(g);
   // passes::UnpackBatchNorm(g);
@@ -57,25 +57,36 @@
 }
 
 torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
+  LOG_DEBUG("Input module is being frozen by torch::jit::freeze_module");
   auto mod_ = torch::jit::freeze_module(mod);
   return mod_;
 }
 
 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
     const torch::jit::script::Module& mod,
-    std::string method_name) {
-  auto lowered_mod = mod; // LowerModule(mod);
+    std::string method_name,
+    bool unfreeze_module = false) {
+  auto lowered_mod = unfreeze_module ? mod : LowerModule(mod);
   auto g = lowered_mod.get_method(method_name).graph();
   LOG_GRAPH(*g);
 
   // Go through TRTorch Lowering to reformat graph to be conversion friendly
   // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
-  LOG_GRAPH("TRTorch Graph Lowering");
-  // lowering::LowerGraph(g);
+  // unfreeze_module is used to skip constant folding on weights in the network.
+  // In quantization aware trained (QAT) models, weights are passed through quantize and
+  // dequantize nodes which should not be folded. So unfreeze_module is set to true for QAT models.
+  if (!unfreeze_module) {
+    LOG_GRAPH("TRTorch Graph Lowering");
+    lowering::LowerGraph(g, false);
+  }
+
   LOG_GRAPH("LibTorch Lowering");
   auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
-  lowering::LowerGraph(graph_and_ivalues.first);
+
+  if (unfreeze_module) {
+    LOG_GRAPH("TRTorch Graph Lowering");
+    lowering::LowerGraph(graph_and_ivalues.first, true);
+  }
 
   // Is this necessary?
   lowering::LowerBlock(g->block());
diff --git a/core/lowering/lowering.h b/core/lowering/lowering.h
index 119e082848..fd82fd0852 100644
--- a/core/lowering/lowering.h
+++ b/core/lowering/lowering.h
@@ -7,11 +7,12 @@ namespace core {
 namespace lowering {
 
 void LowerBlock(torch::jit::Block* b);
-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g);
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse /*=false*/);
 torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
     const torch::jit::script::Module& mod,
-    std::string method_name);
+    std::string method_name,
+    bool unfreeze_module /*=false*/);
 
 } // namespace lowering
 } // namespace core
diff --git a/cpp/api/include/trtorch/trtorch.h b/cpp/api/include/trtorch/trtorch.h
index c50b3909e6..f4e834ba9d 100644
--- a/cpp/api/include/trtorch/trtorch.h
+++ b/cpp/api/include/trtorch/trtorch.h
@@ -262,9 +262,9 @@ struct TRTORCH_API CompileSpec {
    * Emum for selecting engine capability
    */
   enum class EngineCapability : int8_t {
-    kDEFAULT,
-    kSAFE_GPU,
-    kSAFE_DLA,
+    kSTANDARD,
+    kSAFETY,
+    kDLA_STANDALONE,
   };
 
   class TRTORCH_API TensorFormat {
@@ -686,12 +686,12 @@ struct TRTORCH_API CompileSpec {
    * This is the behavior of FP32 layers by default.
    */
   bool disable_tf32 = false;
-
-  /**
+
+  /**
    * Enable sparsity for weights of conv and FC layers
    */
   bool sparse_weights = false;
-
+
   /**
    * Build a refitable engine
    */
diff --git a/cpp/api/src/compile_spec.cpp b/cpp/api/src/compile_spec.cpp
index 5b295f1579..63e7a7082f 100644
--- a/cpp/api/src/compile_spec.cpp
+++ b/cpp/api/src/compile_spec.cpp
@@ -405,7 +405,13 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
 
   if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
       internal.convert_info.engine_settings.enabled_precisions.end()) {
-    internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
+    if (external.ptq_calibrator) {
+      internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
+    } else {
+      // No calibrator provided; assume a QAT network with Q/DQ nodes and keep the module unfrozen.
+      internal.convert_info.engine_settings.unfreeze_module = true;
+      internal.convert_info.engine_settings.calibrator = nullptr;
+    }
   } else {
     internal.convert_info.engine_settings.calibrator = nullptr;
   }
diff --git a/py/trtorch/csrc/register_tensorrt_classes.cpp b/py/trtorch/csrc/register_tensorrt_classes.cpp
index 257d70cf09..bb00be268c 100644
--- a/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/py/trtorch/csrc/register_tensorrt_classes.cpp
@@ -47,7 +47,7 @@ void RegisterTRTCompileSpec() {
           .def("_set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
          .def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
          .def("__str__", &trtorch::pyapi::CompileSpec::stringify);
-
+
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, sparse_weights);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);
diff --git a/py/trtorch/csrc/tensorrt_backend.cpp b/py/trtorch/csrc/tensorrt_backend.cpp
index b3f2438786..da1600818c 100644
--- a/py/trtorch/csrc/tensorrt_backend.cpp
+++ b/py/trtorch/csrc/tensorrt_backend.cpp
@@ -32,7 +32,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
     const auto& method_name = it->key();
     auto method = mod.get_method(method_name);
     auto graph = method.graph();
-    core::lowering::LowerGraph(graph);
+    core::lowering::LowerGraph(graph, false);
   }
 
   auto handles = c10::impl::GenericDict(
diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp
index 0500936ce4..4d41d490c4 100644
--- a/py/trtorch/csrc/tensorrt_classes.cpp
+++ b/py/trtorch/csrc/tensorrt_classes.cpp
@@ -181,8 +181,15 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   for (auto p : enabled_precisions) {
     info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
   }
-
-  info.convert_info.engine_settings.calibrator = ptq_calibrator;
+  if (ptq_calibrator) {
+    info.convert_info.engine_settings.calibrator = ptq_calibrator;
+  } else {
+    if (info.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
+        info.convert_info.engine_settings.enabled_precisions.end()) {
+      std::cout << "Enabling unfreeze_module: INT8 is enabled without a calibrator (QAT workflow)" << std::endl;
+      info.convert_info.engine_settings.unfreeze_module = true;
+    }
+  }
   info.convert_info.engine_settings.sparse_weights = sparse_weights;
   info.convert_info.engine_settings.disable_tf32 = disable_tf32;
   info.convert_info.engine_settings.refit = refit;
diff --git a/tests/core/conversion/converters/test_shuffle.cpp b/tests/core/conversion/converters/test_shuffle.cpp
index 901dccf9c4..0b99d4147e 100644
--- a/tests/core/conversion/converters/test_shuffle.cpp
+++ b/tests/core/conversion/converters/test_shuffle.cpp
@@ -241,6 +241,30 @@ TEST(Converters, ATenTransposeConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenTConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %out : Tensor = aten::t(%x.1)
+        return (%out))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(0, 5, {3, 4}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+
+  std::cout << "Running JIT" << std::endl;
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  std::cout << "Running TRT" << std::endl;
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 TEST(Converters, ATenTransposeNegativeConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
@@ -312,7 +336,6 @@ TEST(Converters, ATenPixelShuffle3DConvertsCorrectly) {
   in = at::clone(in);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
-  // auto trt = trt_results[0].reshape_as(jit_results[0]);
 
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }
diff --git a/tests/modules/hub.py b/tests/modules/hub.py
index 239fc5c2e9..bc9a81b070 100644
--- a/tests/modules/hub.py
+++ b/tests/modules/hub.py
@@ -54,18 +54,10 @@
         "model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True),
         "path": "both"
     },
-    "fcn_resnet101": {
-        "model": torch.hub.load('pytorch/vision:v0.9.0', 'fcn_resnet101', pretrained=True),
-        "path": "script"
-    },
     "ssd": {
         "model": torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math="fp32"),
         "path": "trace"
     },
-    "faster_rcnn": {
-        "model": models.detection.fasterrcnn_resnet50_fpn(pretrained=True),
-        "path": "script"
-    },
     "efficientnet_b0": {
         "model": timm.create_model('efficientnet_b0', pretrained=True),
         "path": "script"
diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp
index e58aa85912..d8a8058240 100644
--- a/tests/util/run_graph_engine.cpp
+++ b/tests/util/run_graph_engine.cpp
@@ -69,7 +69,7 @@ std::vector<at::Tensor> RunGraphEngine(
   auto in = toInputs(inputs);
   auto info = core::conversion::ConversionInfo(in);
   info.engine_settings.workspace_size = 1 << 20;
-  info.engine_settings.op_precision = op_precision;
+  info.engine_settings.enabled_precisions.insert(op_precision);
   std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
   return RunEngine(eng, inputs);
 }
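--
Usage sketch (illustrative; not part of the commit): the path this patch enables is INT8 compilation of a quantization-aware-trained TorchScript module with no calibrator attached, in which case unfreeze_module is set so the Q/DQ nodes survive lowering instead of being constant-folded. The snippet below is a minimal sketch assuming the Python API surface of this TRTorch revision (trtorch.Input plus the "inputs" and "enabled_precisions" spec keys; older revisions used "input_shapes"/"op_precision"), and the plain nn.Sequential stands in for a real QAT model that would carry quantize/dequantize nodes inserted by a tool such as pytorch-quantization.

import torch
import torch.nn as nn
import trtorch

# Stand-in for a QAT network; a real one would contain Q/DQ nodes.
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU()).eval().cuda()
scripted = torch.jit.script(model)

compile_spec = {
    # Assumed spec keys for this revision of the Python API.
    "inputs": [trtorch.Input((1, 3, 224, 224))],
    # INT8 enabled with no calibrator: with this patch the module stays unfrozen
    # (unfreeze_module = true) so Q/DQ nodes are not constant-folded away.
    "enabled_precisions": {torch.float, torch.int8},
}

trt_mod = trtorch.compile(scripted, compile_spec)
print(trt_mod(torch.randn(1, 3, 224, 224).cuda()).shape)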