From c8dc6e950d82b85b523a42269b236d09121f8721 Mon Sep 17 00:00:00 2001 From: Ruoqian Guo Date: Tue, 19 Oct 2021 11:46:36 +0000 Subject: [PATCH] feat: support aten::conv1d and aten::conv_transpose1d Signed-off-by: Ruoqian Guo --- .../converters/impl/conv_deconv.cpp | 85 +++++++++++++----- .../converters/test_conv_deconv.cpp | 86 +++++++++++++++++++ 2 files changed, 148 insertions(+), 23 deletions(-) diff --git a/core/conversion/converters/impl/conv_deconv.cpp b/core/conversion/converters/impl/conv_deconv.cpp index 4482618626..4a15215b93 100644 --- a/core/conversion/converters/impl/conv_deconv.cpp +++ b/core/conversion/converters/impl/conv_deconv.cpp @@ -10,18 +10,19 @@ namespace converters { namespace impl { namespace { -bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) { +bool add_conv_deconv( + ConversionCtx* ctx, + const torch::jit::Node* n, + args& args, + nvinfer1::Dims& stride, + nvinfer1::Dims& padding, + nvinfer1::Dims& dilation, + bool transposed, + nvinfer1::Dims& out_padding, + int64_t groups) { // Input to conv/deconv auto in = args[0].ITensor(); - // Conv /deconv parameters - auto stride = util::toDims(args[3].unwrapToIntList()); - auto padding = util::toDims(args[4].unwrapToIntList()); - auto dilation = util::toDims(args[5].unwrapToIntList()); - bool transposed = args[6].unwrapToBool(); - auto out_padding = util::toDims(args[7].unwrapToIntList()); - int64_t groups = args[8].unwrapToInt(); - // Reshape the parameters to 2D if needed if (stride.nbDims == 1) { stride = util::unsqueezeDims(stride, 1, 1); @@ -174,28 +175,66 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) return true; } -auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns() - .pattern({ - R"SIG(aten::_convolution(Tensor input, Tensor weight, +auto conv_registrations TRTORCH_UNUSED = + RegisterNodeConversionPatterns() + .pattern({ + R"SIG(aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor))SIG", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - return add_conv_deconv(ctx, n, args); - }}) - .pattern({ - R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight, + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // Conv /deconv parameters + auto stride = util::toDims(args[3].unwrapToIntList()); + auto padding = util::toDims(args[4].unwrapToIntList()); + auto dilation = util::toDims(args[5].unwrapToIntList()); + bool transposed = args[6].unwrapToBool(); + auto out_padding = util::toDims(args[7].unwrapToIntList()); + int64_t groups = args[8].unwrapToInt(); + return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups); + }}) + .pattern({ + R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor))SIG", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - // This pattern is only matched for traced JIT models which do not - // have allow_tf32 bool in the function signature. The TRT conversion - // code is exactly same as the above call. - return add_conv_deconv(ctx, n, args); - }}); + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // This pattern is only matched for traced JIT models which do not + // have allow_tf32 bool in the function signature. The TRT conversion + // code is exactly same as the above call. + auto stride = util::toDims(args[3].unwrapToIntList()); + auto padding = util::toDims(args[4].unwrapToIntList()); + auto dilation = util::toDims(args[5].unwrapToIntList()); + bool transposed = args[6].unwrapToBool(); + auto out_padding = util::toDims(args[7].unwrapToIntList()); + int64_t groups = args[8].unwrapToInt(); + return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups); + }}) + .pattern( + {R"SIG(aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor)SIG", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // Conv /deconv parameters + auto stride = util::toDims(args[3].unwrapToIntList()); + auto padding = util::toDims(args[4].unwrapToIntList()); + auto dilation = util::toDims(args[5].unwrapToIntList()); + bool transposed = false; + nvinfer1::Dims out_padding{1, {0}}; + int64_t groups = args[6].unwrapToInt(); + return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups); + }}) + .pattern( + {R"SIG(aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor)SIG", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + // Conv /deconv parameters + auto stride = util::toDims(args[3].unwrapToIntList()); + auto padding = util::toDims(args[4].unwrapToIntList()); + auto out_padding = util::toDims(args[5].unwrapToIntList()); + bool transposed = true; + int64_t groups = args[6].unwrapToInt(); + auto dilation = util::toDims(args[7].unwrapToIntList()); + return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups); + }}); } // namespace } // namespace impl } // namespace converters diff --git a/tests/core/conversion/converters/test_conv_deconv.cpp b/tests/core/conversion/converters/test_conv_deconv.cpp index 87dd6c2406..0dda798caf 100644 --- a/tests/core/conversion/converters/test_conv_deconv.cpp +++ b/tests/core/conversion/converters/test_conv_deconv.cpp @@ -10,6 +10,12 @@ // int[] output_padding, int groups, bool benchmark, // bool deterministic, bool cudnn_enabled) -> (Tensor) +// aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> +// Tensor + +// aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, +// int groups, int[] dilation) -> Tensor + void conv_test_helper(std::string graph_ir) { auto g = std::make_shared(); torch::jit::parseIR(graph_ir, g.get()); @@ -116,6 +122,86 @@ TEST(Converters, ATenConvolution1dConvertsCorrectly) { ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } +TEST(Converters, ATenConv1dConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(4, 3, 3, strides=[9, 3, 1]), + %2 : Float(3)): + %3 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=0]() + %5 : int = prim::Constant[value=1]() + %8 : int[] = prim::ListConstruct(%3) + %9 : int[] = prim::ListConstruct(%4) + %10 : int[] = prim::ListConstruct(%5) + %12 : Tensor = aten::conv1d(%0, %1, %2, %8, %9, %10, %3) + return (%12))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA}); + auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA}); + auto b = at::randint(1, 10, {4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto jit_w = at::clone(w); + auto jit_b = at::clone(b); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_w = at::clone(w); + auto trt_b = at::clone(b); + params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); + + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenConvTranspose1dConvertsCorrectly) { + const auto graph = R"IR( + graph(%0 : Tensor, + %1 : Float(4, 3, 3, strides=[9, 3, 1]), + %2 : Float(3)): + %3 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=0]() + %5 : int = prim::Constant[value=1]() + %6 : int = prim::Constant[value=0]() + %8 : int[] = prim::ListConstruct(%3) + %9 : int[] = prim::ListConstruct(%4) + %10 : int[] = prim::ListConstruct(%5) + %11 : int[] = prim::ListConstruct(%6) + %12 : Tensor = aten::conv_transpose1d(%0, %1, %2, %8, %9, %11, %3, %10) + return (%12))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 2, {1, 8, 3}, {at::kCUDA}); + auto w = at::randint(1, 2, {8, 4, 3}, {at::kCUDA}); + auto b = at::randint(1, 10, {4}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto jit_w = at::clone(w); + auto jit_b = at::clone(b); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b}); + auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + auto trt_w = at::clone(w); + auto trt_b = at::clone(b); + params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b}); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); + + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) { const auto graph = R"IR( graph(%0 : Tensor,