From c77604181d79cabe730139059b8dd791c62443c1 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Thu, 5 Nov 2020 00:33:50 +0000 Subject: [PATCH] Support batch norm with rank <= 4 --- src/runtime/contrib/tensorrt/tensorrt_ops.h | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.h b/src/runtime/contrib/tensorrt/tensorrt_ops.h index 58fc8d7acb65..9f1f0feabf97 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.h +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.h @@ -528,6 +528,10 @@ class BatchNormOpConverter : public TrtOpConverter { CHECK_EQ(gamma.count, var.count); CHECK(bn_attr->axis == 1 || bn_attr->axis == 3); const bool need_transpose = bn_attr->axis == 3; + const int required_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4; + auto input_dims = TrtDimsToVector(input->getDimensions()); + CHECK(input_dims.size() > 0 && input_dims.size() <= required_rank); + const bool need_reshape = input_dims.size() != required_rank; void* weight_scale_ptr = new float[gamma.count]; nvinfer1::Weights weight_scale{nvinfer1::DataType::kFLOAT, weight_scale_ptr, gamma.count}; @@ -557,10 +561,20 @@ class BatchNormOpConverter : public TrtOpConverter { if (need_transpose) { input = Transpose(params, input, {0, 3, 1, 2}); } + if (need_reshape) { + // Add dims of size 1 until rank is required_rank. + std::vector new_shape(input_dims); + while (new_shape.size() < required_rank) new_shape.insert(new_shape.end(), 1); + input = Reshape(params, input, new_shape); + } nvinfer1::IScaleLayer* scale_layer = params->network->addScale( *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power); CHECK(scale_layer != nullptr); auto output = scale_layer->getOutput(0); + if (need_reshape) { + // Remove added dims. + output = Reshape(params, output, input_dims); + } if (need_transpose) { output = Transpose(params, output, {0, 2, 3, 1}); }