From 64618721f2e49ca5e0fe6fa096137b2364fae119 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 8 May 2020 00:14:41 -0700 Subject: [PATCH] fix(aten::batch_norm): A new batch norm implementation that hopefully doesn't have the same performance cost Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .../conversion/converters/impl/batch_norm.cpp | 141 +++++++----------- core/lowering/lowering.cpp | 1 + core/lowering/passes/unpack_batch_norm.cpp | 3 + tests/core/converters/BUILD | 5 + tests/core/converters/test_batch_norm.cpp | 36 +++++ 5 files changed, 99 insertions(+), 87 deletions(-) create mode 100644 tests/core/converters/test_batch_norm.cpp diff --git a/core/conversion/converters/impl/batch_norm.cpp b/core/conversion/converters/impl/batch_norm.cpp index f4fcdc9e29..bd923310a0 100644 --- a/core/conversion/converters/impl/batch_norm.cpp +++ b/core/conversion/converters/impl/batch_norm.cpp @@ -1,3 +1,4 @@ +#include "torch/torch.h" #include "core/util/prelude.h" #include "core/conversion/converters/converters.h" @@ -8,93 +9,59 @@ namespace converters { namespace impl { namespace { -bool ConvertConvBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& args) { - auto input = args[0].ITensor(); - auto shape = util::toVec(input->getDimensions()); - LOG_WARNING("Assuming channel dimension is 3 because input is from a conv layer, please verify"); - auto gamma = args[1].unwrapToTensor(at::full({shape[shape.size() - 3]}, 1)); - auto beta = args[2].unwrapToTensor(at::full({shape[shape.size() - 3]}, 1)); - auto mean = args[3].unwrapToTensor(at::full({shape[shape.size() - 3]}, 0)); - auto var = args[4].unwrapToTensor(at::full({shape[shape.size() - 3]}, 0)); - LOG_WARNING("Momentum argument is disregarded"); - //auto momentum = args[6].unwrapToDouble(0); - auto eps = args[7].unwrapToDouble(0); - - auto w = at::diag(gamma / at::sqrt(var + eps)); - auto w_shape = w.sizes().vec(); - w_shape.push_back(1); - w_shape.push_back(1); - w = w.reshape(w_shape); - auto b = 
beta - gamma * (mean / at::sqrt(var + eps)); - - auto weights = Weights(ctx, w); - auto bias = Weights(ctx, b); - - auto bn_as_conv = ctx->net->addConvolutionNd(*input, weights.num_output_maps, weights.kernel_shape, weights.data, bias.data); - TRTORCH_CHECK(bn_as_conv, "Unable to create fused batch norm from node: " << *n); - - bn_as_conv->setName(util::node_info(n).c_str()); - - auto bn_out = ctx->AssociateValueAndTensor(n->outputs()[0], bn_as_conv->getOutput(0)); - LOG_DEBUG("Output tensor shape: " << bn_out->getDimensions()); - return true; -} - -bool ConvertLinearBatchNorm(ConversionCtx* ctx, const torch::jit::Node* n, args& args) { - auto input = args[0].ITensor(); - auto shape = util::toVec(input->getDimensions()); - auto gamma = args[1].unwrapToTensor(at::full({shape},1)); - auto beta = args[2].unwrapToTensor(at::full({shape},1)); - auto mean = args[3].unwrapToTensor(at::full({shape},0)); - auto var = args[4].unwrapToTensor(at::full({shape},0)); - LOG_WARNING("Momentum argument is disregarded"); - //auto momentum = args[6].unwrapToDouble(0); - auto eps = args[7].unwrapToDouble(0); - - auto mean_ = tensor_to_const(ctx, mean); - auto bot_half = at::sqrt(var + eps); - auto bot_half_ = tensor_to_const(ctx, bot_half); - auto gamma_ = tensor_to_const(ctx, gamma); - auto beta_ = tensor_to_const(ctx, beta); - - auto top_half = ctx->net->addElementWise(*input, *mean_, nvinfer1::ElementWiseOperation::kSUB); - auto top_half_out = top_half->getOutput(0); - auto x_hat = ctx->net->addElementWise(*top_half_out, *bot_half_, nvinfer1::ElementWiseOperation::kDIV); - auto x_hat_out = x_hat->getOutput(0); - auto bn_scaled = ctx->net->addElementWise(*gamma_, *x_hat_out, nvinfer1::ElementWiseOperation::kPROD); - auto bn_scaled_out = bn_scaled->getOutput(0); - auto bn_biased = ctx->net->addElementWise(*beta_, *bn_scaled_out, nvinfer1::ElementWiseOperation::kSUM); - auto bn_biased_out = bn_biased->getOutput(0); - - bn_biased->setName(util::node_info(n).c_str()); - 
ctx->AssociateValueAndTensor(n->outputs()[0], bn_biased_out); - - return true; -} - -volatile auto batch_norm_registrations = RegisterNodeConversionPatterns() - .pattern({ - R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta, - Tensor? mean, Tensor? var, - bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto input = args[0].ITensor(); - auto shape = input->getDimensions(); - auto gamma = args[1].unwrapToTensor(); - - if (/*training*/ args[5].unwrapToBool()) { - LOG_WARNING(R"WARN(TRTorch only converts forward pass of graphs, but saw training = True, may see - unexpected behavior, consider placing module in eval mode before exporting the TorchScript module)WARN"); - } - - // If gamma is None this fails - if (util::volume(shape) == gamma.numel()) { - return ConvertLinearBatchNorm(ctx, n, args); - } else { - return ConvertConvBatchNorm(ctx, n, args); - } - } - }); +auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() + .pattern({ + R"SIG(aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta, + Tensor? mean, Tensor? 
var, + bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor))SIG", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto input = args[0].ITensor(); + auto orig_shape = input->getDimensions(); + auto shape = util::toVec(orig_shape); + auto options = torch::TensorOptions().dtype(torch::kFloat32); + auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options})); + auto beta = args[2].unwrapToTensor(at::full({shape}, 1, {options})); + auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options})); + auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options})); + auto eps = args[7].unwrapToDouble(1e-5f); + + LOG_DEBUG("momentum disregarded"); + LOG_DEBUG("training disregarded"); + LOG_DEBUG("cudnn disregarded"); + + auto should_unpack = util::toVec(orig_shape).size() < 4; + if (should_unpack) { + // expand spatial dims from 1D to 2D + auto new_shape = util::toDimsPad(util::toVec(orig_shape), 4); + LOG_DEBUG("Input shape is less than 4D got: " << orig_shape << ", inserting shuffle layer to reshape to 4D tensor shape: " << new_shape); + auto in_shuffle = ctx->net->addShuffle(*input); + in_shuffle->setReshapeDimensions(new_shape); + in_shuffle->setName(std::string("[Reshape input to " + util::toStr(new_shape) + ']').c_str()); + input = in_shuffle->getOutput(0); + } + + auto scale = gamma / torch::sqrt(var + eps); + auto bias = beta - mean * scale; + + auto scale_weights = Weights(ctx, scale); + auto bias_weights = Weights(ctx, bias); + + auto bn = ctx->net->addScaleNd(*input, nvinfer1::ScaleMode::kCHANNEL, bias_weights.data, scale_weights.data, {}, 1); + bn->setName(util::node_info(n).c_str()); + auto out_tensor = bn->getOutput(0); + + if (should_unpack) { + LOG_DEBUG("Inserting shuffle layer to reshape to back to original shape: " << orig_shape); + auto out_shuffle = ctx->net->addShuffle(*out_tensor); + out_shuffle->setReshapeDimensions(orig_shape); + out_shuffle->setName(std::string("[Reshape output to " + 
util::toStr(orig_shape) + ']').c_str()); + out_tensor = out_shuffle->getOutput(0); + } + + ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor); + return true; + } + }); } // namespace diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index ef67374770..036b8f50d6 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -27,6 +27,7 @@ void LowerGraph(std::shared_ptr& g) { passes::FuseFlattenLinear(g); passes::Conv2DToConvolution(g); passes::UnpackAddMM(g); + //passes::UnpackBatchNorm(g); passes::UnpackLogSoftmax(g); //passes::RemoveDimExeception(g); //irfusers::UnpackBatchNorm(g); diff --git a/core/lowering/passes/unpack_batch_norm.cpp b/core/lowering/passes/unpack_batch_norm.cpp index 19a9ad29bd..13f2e710f8 100644 --- a/core/lowering/passes/unpack_batch_norm.cpp +++ b/core/lowering/passes/unpack_batch_norm.cpp @@ -41,6 +41,9 @@ void UnpackBatchNorm(std::shared_ptr& graph) { torch::jit::SubgraphRewriter unpack_batch_norm; unpack_batch_norm.RegisterRewritePattern(batch_norm_pattern, expanded_batch_norm_pattern); unpack_batch_norm.runOnGraph(graph); + LOG_DEBUG("[Lowering Batch Norm]: momentum disregarded"); + LOG_DEBUG("[Lowering Batch Norm]: training disregarded"); + LOG_DEBUG("[Lowering Batch Norm]: cudnn disregarded"); LOG_GRAPH("Post unpack batchnorm: " << *graph); } } // Namespace passes diff --git a/tests/core/converters/BUILD b/tests/core/converters/BUILD index b47692a886..e3973a38fc 100644 --- a/tests/core/converters/BUILD +++ b/tests/core/converters/BUILD @@ -4,6 +4,10 @@ converter_test( name = "test_activation" ) +converter_test( + name = "test_batch_norm" +) + converter_test( name = "test_conv" ) @@ -44,6 +48,7 @@ test_suite( name = "test_converters", tests = [ ":test_activation", + ":test_batch_norm", ":test_conv", ":test_element_wise", ":test_linear", diff --git a/tests/core/converters/test_batch_norm.cpp b/tests/core/converters/test_batch_norm.cpp new file mode 100644 index 0000000000..727d7a17ea --- /dev/null 
+++ b/tests/core/converters/test_batch_norm.cpp @@ -0,0 +1,36 @@ +#include +#include "gtest/gtest.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "tests/util/util.h" +#include "core/compiler.h" + +TEST(Converters, ATenBatchNormConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1: Float(5), + %2: Float(5), + %3: Float(5), + %4: Float(5)): + %5 : bool = prim::Constant[value=0]() + %6 : float = prim::Constant[value=1.0000000000000001e-05]() + %7 : float = prim::Constant[value=0.10000000000000001]() + %8 : Tensor = aten::batch_norm(%0, %1, %2, %3, %4, %5, %6, %7, %5) + return (%8))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA}); + auto gamma = at::randint(1, 10, {5}, {at::kCUDA}); + auto beta = at::randint(1, 10, {5}, {at::kCUDA}); + auto mean = at::randint(1, 10, {5}, {at::kCUDA}); + auto var = at::randint(1, 10, {5}, {at::kCUDA}); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta, mean, var}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {in}); + + params = trtorch::core::conversion::get_named_params(g->inputs(), {gamma, beta, mean, var}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +}