Skip to content

Commit

Permalink
Merge pull request #1806 from mfeliz-cruise/michael.feliz/baddbmm
Browse files Browse the repository at this point in the history
feat: add support for aten::baddbmm
  • Loading branch information
peri044 authored Apr 6, 2023
2 parents 745af55 + a4e55da commit 78b571c
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 0 deletions.
79 changes: 79 additions & 0 deletions core/conversion/converters/impl/matrix_multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,85 @@ auto mm_registrations TORCHTRT_UNUSED =
mm_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}})
.pattern(
{"aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Converts aten::baddbmm, i.e. out = beta * self + alpha * (batch1 @ batch2),
// into TensorRT layers: one batched matrix-multiply plus (conditionally)
// elementwise scale/add layers that are skipped when beta/alpha take their
// identity values.
auto self = args[0].ITensorOrFreeze(ctx);
auto bat1 = args[1].ITensorOrFreeze(ctx);
auto bat2 = args[2].ITensorOrFreeze(ctx);
nvinfer1::Dims batch1Dims = bat1->getDimensions();
nvinfer1::Dims batch2Dims = bat2->getDimensions();

// check dimensions
// baddbmm requires both batch operands to be rank-3: (b, n, m) x (b, m, p).
// NOTE(review): these compare static dims; fully dynamic (-1) dims would not
// satisfy the equality checks below — confirm dynamic-shape inputs are out of
// scope for this converter.
TORCHTRT_CHECK(
batch1Dims.nbDims == 3,
"Expected 3-dimensional tensor, but got "
<< batch1Dims.nbDims
<< "-dimensional tensor for argument 'batch1' (while checking arguments for baddbmm)");
TORCHTRT_CHECK(
batch2Dims.nbDims == 3,
"Expected 3-dimensional tensor, but got "
<< batch2Dims.nbDims
<< "-dimensional tensor for argument 'batch2' (while checking arguments for baddbmm)");
// Batch sizes must agree...
TORCHTRT_CHECK(
batch1Dims.d[0] == batch2Dims.d[0],
"Expected tensor to have size " << batch1Dims.d[0] << " at dimension 0, but got size "
<< batch2Dims.d[0]
<< " for argument 'batch2' (while checking arguments for baddbmm)");
// ...and the contraction dims must match: batch1 is (b, n, m), batch2 is (b, m, p).
TORCHTRT_CHECK(
batch1Dims.d[2] == batch2Dims.d[1],
"Expected tensor to have size " << batch1Dims.d[2] << " at dimension 1, but got size "
<< batch2Dims.d[1]
<< " for argument 'batch2' (while checking arguments for baddbmm)");

// batch1 @ batch2 as a single TensorRT batched matrix multiply (no transpose on
// either operand).
auto mm_layer = ctx->net->addMatrixMultiply(
*bat1, nvinfer1::MatrixOperation::kNONE, *bat2, nvinfer1::MatrixOperation::kNONE);
TORCHTRT_CHECK(mm_layer, "Unable to create matrix multiplication for node: " << *n);
mm_layer->setName((util::node_info(n) + "_matmul").c_str());

// mm_out tracks the "current" output tensor as optional scale/add layers are appended.
auto mm_out = mm_layer->getOutput(0);

// Scale the matmul result by alpha (args[4] per the schema), skipping the
// elementwise layer entirely when alpha == 1.
auto alpha = args[4].unwrapToScalar();
if (alpha.to<float>() != 1.) {
auto alpha_tensor = scalar_to_tensor(ctx, alpha);
auto alpha_layer = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
mm_out,
alpha_tensor,
util::node_info(n) + std::string("_alpha_mul"));
TORCHTRT_CHECK(alpha_layer, "Unable to create alpha_mul layer from node: " << *n);
mm_out = alpha_layer->getOutput(0);
}

// beta (args[3] per the schema) scales 'self' before it is added in.
auto beta = args[3].unwrapToScalar();
// If beta is 0, then input will be ignored, and nan and inf in it will not be propagated.
if (beta.to<float>() != 0.) {
// Skip the scale layer when beta == 1 (identity).
if (beta.to<float>() != 1.) {
auto beta_tensor = scalar_to_tensor(ctx, beta);
auto beta_layer = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
self,
beta_tensor,
util::node_info(n) + std::string("_beta_mul"));
TORCHTRT_CHECK(beta_layer, "Unable to create beta_mul layer from node: " << *n);
self = beta_layer->getOutput(0);
}
// out = beta * self + alpha * (batch1 @ batch2)
auto self_add_layer = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kSUM,
self,
mm_out,
util::node_info(n) + std::string("_self_add"));
TORCHTRT_CHECK(self_add_layer, "Unable to create self_add layer from node: " << *n);
mm_out = self_add_layer->getOutput(0);
}

// Bind the final tensor to the node's output so downstream converters can find it.
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_out);
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
}});
Expand Down
69 changes: 69 additions & 0 deletions tests/core/conversion/converters/test_matrix_multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,72 @@ TEST(Converters, ATenBMMConvertsCorrectly) {

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

// Verifies aten::baddbmm conversion against the JIT interpreter with
// non-default scalars (beta = 0.2, alpha = 1.5), exercising both the
// alpha-mul and beta-mul elementwise paths of the converter.
TEST(Converters, ATenBADDBMMConvertsCorrectly) {
// Output is %3 (not %2) so the op result does not shadow graph input %2.
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
%a : float = prim::Constant[value=1.5]()
%b : float = prim::Constant[value=.2]()
%3 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

// self: (10, 3, 5), batch1: (10, 3, 4), batch2: (10, 4, 5) -> out: (10, 3, 5)
auto self = at::randn({10, 3, 5}, {at::kCUDA});
auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

// Verifies the converter's identity shortcuts: with alpha = 1 and beta = 0
// the result is just batch1 @ batch2 and 'self' is ignored entirely (no
// elementwise layers should be emitted).
TEST(Converters, ATenBADDBMMAlphaBetaDisabledConvertsCorrectly) {
// Output is %3 (not %2) so the op result does not shadow graph input %2.
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
%a : float = prim::Constant[value=1]()
%b : float = prim::Constant[value=0]()
%3 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

// self: (10, 3, 5), batch1: (10, 3, 4), batch2: (10, 4, 5) -> out: (10, 3, 5)
auto self = at::randn({10, 3, 5}, {at::kCUDA});
auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

// Verifies the schema-default scalars (beta = 1, alpha = 1): the converter
// should skip both scale layers and emit only matmul + self add.
TEST(Converters, ATenBADDBMMScalarDefaultsConvertsCorrectly) {
// Output is %3 (not %2) so the op result does not shadow graph input %2.
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor, %2 : Tensor):
%a : float = prim::Constant[value=1]()
%b : float = prim::Constant[value=1]()
%3 : Tensor = aten::baddbmm(%0, %1, %2, %b, %a)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

// self: (10, 3, 5), batch1: (10, 3, 4), batch2: (10, 4, 5) -> out: (10, 3, 5)
auto self = at::randn({10, 3, 5}, {at::kCUDA});
auto bat1 = at::randn({10, 3, 4}, {at::kCUDA});
auto bat2 = at::randn({10, 4, 5}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {self, bat1, bat2});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {self, bat1, bat2});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

0 comments on commit 78b571c

Please sign in to comment.