From a4e55da79e8e6361f2d30e160d4c043f42191036 Mon Sep 17 00:00:00 2001
From: Michael Feliz
Date: Tue, 4 Apr 2023 14:42:03 -0700
Subject: [PATCH] add support for aten::baddbmm

---
 .../converters/impl/matrix_multiply.cpp       | 79 +++++++++++++++++++
 .../converters/test_matrix_multiply.cpp       | 69 ++++++++++++++++
 2 files changed, 148 insertions(+)

diff --git a/core/conversion/converters/impl/matrix_multiply.cpp b/core/conversion/converters/impl/matrix_multiply.cpp
index ec5703cd37..c4b12da810 100644
--- a/core/conversion/converters/impl/matrix_multiply.cpp
+++ b/core/conversion/converters/impl/matrix_multiply.cpp
@@ -72,6 +72,85 @@ auto mm_registrations TORCHTRT_UNUSED =
           mm_layer->setName(util::node_info(n).c_str());
           auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
 
+          LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+          return true;
+        }})
+        .pattern(
+            {"aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto bat1 = args[1].ITensorOrFreeze(ctx);
+               auto bat2 = args[2].ITensorOrFreeze(ctx);
+               nvinfer1::Dims batch1Dims = bat1->getDimensions();
+               nvinfer1::Dims batch2Dims = bat2->getDimensions();
+
+               // check dimensions
+               TORCHTRT_CHECK(
+                   batch1Dims.nbDims == 3,
+                   "Expected 3-dimensional tensor, but got "
+                       << batch1Dims.nbDims
+                       << "-dimensional tensor for argument 'batch1' (while checking arguments for baddbmm)");
+               TORCHTRT_CHECK(
+                   batch2Dims.nbDims == 3,
+                   "Expected 3-dimensional tensor, but got "
+                       << batch2Dims.nbDims
+                       << "-dimensional tensor for argument 'batch2' (while checking arguments for baddbmm)");
+               TORCHTRT_CHECK(
+                   batch1Dims.d[0] == batch2Dims.d[0],
+                   "Expected tensor to have size " << batch1Dims.d[0] << " at dimension 0, but got size "
+                                                   << batch2Dims.d[0]
+                                                   << " for argument 'batch2' (while checking arguments for baddbmm)");
+               TORCHTRT_CHECK(
+                   batch1Dims.d[2] == batch2Dims.d[1],
+                   "Expected tensor to have size " << batch1Dims.d[2] << " at dimension 1, but got size "
+                                                   << batch2Dims.d[1]
+                                                   << " for argument 'batch2' (while checking arguments for baddbmm)");
+
+               auto mm_layer = ctx->net->addMatrixMultiply(
+                   *bat1, nvinfer1::MatrixOperation::kNONE, *bat2, nvinfer1::MatrixOperation::kNONE);
+               TORCHTRT_CHECK(mm_layer, "Unable to create matrix multiplication for node: " << *n);
+               mm_layer->setName((util::node_info(n) + "_matmul").c_str());
+
+               auto mm_out = mm_layer->getOutput(0);
+
+               auto alpha = args[4].unwrapToScalar();
+               if (alpha.to<float>() != 1.) {
+                 auto alpha_tensor = scalar_to_tensor(ctx, alpha);
+                 auto alpha_layer = add_elementwise(
+                     ctx,
+                     nvinfer1::ElementWiseOperation::kPROD,
+                     mm_out,
+                     alpha_tensor,
+                     util::node_info(n) + std::string("_alpha_mul"));
+                 TORCHTRT_CHECK(alpha_layer, "Unable to create alpha_mul layer from node: " << *n);
+                 mm_out = alpha_layer->getOutput(0);
+               }
+
+               auto beta = args[3].unwrapToScalar();
+               // If beta is 0, self is ignored, and nan and inf in it will not be propagated.
+               if (beta.to<float>() != 0.) {
+                 if (beta.to<float>() != 1.) {
+                   auto beta_tensor = scalar_to_tensor(ctx, beta);
+                   auto beta_layer = add_elementwise(
+                       ctx,
+                       nvinfer1::ElementWiseOperation::kPROD,
+                       self,
+                       beta_tensor,
+                       util::node_info(n) + std::string("_beta_mul"));
+                   TORCHTRT_CHECK(beta_layer, "Unable to create beta_mul layer from node: " << *n);
+                   self = beta_layer->getOutput(0);
+                 }
+                 auto self_add_layer = add_elementwise(
+                     ctx,
+                     nvinfer1::ElementWiseOperation::kSUM,
+                     self,
+                     mm_out,
+                     util::node_info(n) + std::string("_self_add"));
+                 TORCHTRT_CHECK(self_add_layer, "Unable to create self_add layer from node: " << *n);
+                 mm_out = self_add_layer->getOutput(0);
+               }
+
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_out);
+
           LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
           return true;
         }});
diff --git a/tests/core/conversion/converters/test_matrix_multiply.cpp b/tests/core/conversion/converters/test_matrix_multiply.cpp
index 4f5c726fd4..9c84ba22f6 100644
--- a/tests/core/conversion/converters/test_matrix_multiply.cpp
+++ b/tests/core/conversion/converters/test_matrix_multiply.cpp
@@ -67,3 +67,72 @@ TEST(Converters, ATenBMMConvertsCorrectly) {
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
+
+TEST(Converters, ATenBADDBMMConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
+        %a : float = prim::Constant[value=1.5]()
+        %b : float = prim::Constant[value=0.2]()
+        %3 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto self = at::randn({10, 3, 5}, {at::kCUDA});
+  auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
+  auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenBADDBMMAlphaBetaDisabledConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
+        %a : float = prim::Constant[value=1]()
+        %b : float = prim::Constant[value=0]()
+        %3 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto self = at::randn({10, 3, 5}, {at::kCUDA});
+  auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
+  auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenBADDBMMScalarDefaultsConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
+        %a : float = prim::Constant[value=1]()
+        %b : float = prim::Constant[value=1]()
+        %3 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto self = at::randn({10, 3, 5}, {at::kCUDA});
+  auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
+  auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
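---
Reviewer note, not part of the patch: the converter decomposes baddbmm as
out = beta * self + alpha * bmm(batch1, batch2), with the multiplies skipped
when alpha == 1 or beta == 1 and the add skipped entirely when beta == 0
(matching PyTorch, which does not propagate nan/inf from self in that case).
A minimal standalone LibTorch sketch checking that identity against
at::baddbmm; the shapes mirror the tests above and the beta/alpha values
are arbitrary illustrative choices:

    // baddbmm_check.cpp: verify out == beta * self + alpha * bmm(batch1, batch2)
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto self = torch::randn({10, 3, 5});
      auto bat1 = torch::randn({10, 3, 4});
      auto bat2 = torch::randn({10, 4, 5});
      double beta = 0.2, alpha = 1.5;

      // Fused op versus the decomposition the converter emits.
      auto fused = torch::baddbmm(self, bat1, bat2, beta, alpha);
      auto manual = beta * self + alpha * torch::bmm(bat1, bat2);

      // Expect "true"; the converter relies on exactly this identity.
      std::cout << std::boolalpha
                << torch::allclose(fused, manual, /*rtol=*/1e-5, /*atol=*/1e-6)
                << std::endl;
      return 0;
    }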