diff --git a/core/conversion/converters/impl/max.cpp b/core/conversion/converters/impl/max.cpp index 175cc75461..3ccf165bbe 100644 --- a/core/conversion/converters/impl/max.cpp +++ b/core/conversion/converters/impl/max.cpp @@ -13,47 +13,95 @@ namespace conversion { namespace converters { namespace impl { namespace { -auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( - {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto self = args[0].ITensorOrFreeze(ctx); - auto dim = args[1].unwrapToInt(); - auto keep_dims = args[2].unwrapToBool(); - auto selfDim = util::toVec(self->getDimensions()); - if (dim < 0) { - dim = selfDim.size() + dim; - } - uint32_t shiftDim = 1 << dim; - auto TopKOperation = nvinfer1::TopKOperation::kMAX; - auto topk_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim); - TORCHTRT_CHECK(topk_layer, "Unable to create max layer from node: " << *n); - auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions()); - - nvinfer1::ITensor* out0 = nullptr; - nvinfer1::ITensor* out1 = nullptr; - if (!keep_dims) { - if (topk_dims[dim] == 1) { - auto squeeze_layer = ctx->net->addShuffle(*topk_layer->getOutput(0)); - squeeze_layer->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(0)->getDimensions(), dim)); - TORCHTRT_CHECK(squeeze_layer, "Unable to create squeeze_layer layer from node: " << *n); - out0 = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer->getOutput(0)); - - auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1)); - squeeze_layer_indices->setReshapeDimensions( - util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim)); - TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n); - out1 = ctx->AssociateValueAndTensor(n->outputs()[1], squeeze_layer_indices->getOutput(0)); - } - } else { - out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0)); - out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1)); - } - - LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions()); - LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions()); - - return true; - }}); + +bool min_max_dim(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvinfer1::TopKOperation topKOperation) { + auto self = args[0].ITensorOrFreeze(ctx); + auto dim = args[1].unwrapToInt(); + auto keep_dims = args[2].unwrapToBool(); + auto selfDim = util::toVec(self->getDimensions()); + if (dim < 0) { + dim = selfDim.size() + dim; + } + uint32_t reduce_axes_mask = 1 << dim; + auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask); + TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n); + auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions()); + + nvinfer1::ITensor* out0 = nullptr; + nvinfer1::ITensor* out1 = nullptr; + if (!keep_dims) { + TORCHTRT_CHECK(topk_dims[dim] == 1, "Unexpected size in squeeze dimension. Expected: 1 Actual: " << topk_dims[dim]); + auto squeeze_layer = ctx->net->addShuffle(*topk_layer->getOutput(0)); + squeeze_layer->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(0)->getDimensions(), dim)); + TORCHTRT_CHECK(squeeze_layer, "Unable to create squeeze_layer layer from node: " << *n); + out0 = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer->getOutput(0)); + + auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1)); + squeeze_layer_indices->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim)); + TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n); + out1 = ctx->AssociateValueAndTensor(n->outputs()[1], squeeze_layer_indices->getOutput(0)); + } else { + out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0)); + out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1)); + } + + LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions()); + LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions()); + + return true; +} + +bool arg_min_max(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvinfer1::TopKOperation topKOperation) { + auto self = args[0].ITensorOrFreeze(ctx); + auto dim = args[1].unwrapToInt(); + auto keep_dims = args[2].unwrapToBool(); + auto selfDim = util::toVec(self->getDimensions()); + if (dim < 0) { + dim = selfDim.size() + dim; + } + uint32_t reduce_axes_mask = 1 << dim; + auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask); + TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n); + auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions()); + + nvinfer1::ITensor* out = nullptr; + if (!keep_dims) { + TORCHTRT_CHECK(topk_dims[dim] == 1, "Unexpected size in squeeze dimension. Expected: 1 Actual: " << topk_dims[dim]); + auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1)); + squeeze_layer_indices->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim)); + TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n); + out = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer_indices->getOutput(0)); + } else { + out = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(1)); + } + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + + return true; +} + +auto max_registrations TORCHTRT_UNUSED = + RegisterNodeConversionPatterns() + .pattern( + {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + return min_max_dim(ctx, n, args, nvinfer1::TopKOperation::kMAX); + }}) + .pattern( + {"aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + return min_max_dim(ctx, n, args, nvinfer1::TopKOperation::kMIN); + }}) + .pattern( + {"aten::argmax(Tensor self, int dim, bool keepdim=False) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + return arg_min_max(ctx, n, args, nvinfer1::TopKOperation::kMAX); + }}) + .pattern( + {"aten::argmin(Tensor self, int dim, bool keepdim=False) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + return arg_min_max(ctx, n, args, nvinfer1::TopKOperation::kMIN); + }}); } // namespace } // namespace impl } // namespace converters diff --git a/tests/core/conversion/converters/BUILD b/tests/core/conversion/converters/BUILD index 82bc2f7033..5246de4cf1 100644 --- a/tests/core/conversion/converters/BUILD +++ b/tests/core/conversion/converters/BUILD @@ -71,6 +71,10 @@ converter_test( name = "test_matrix_multiply", ) +converter_test( + name = "test_max", +) + converter_test( name = "test_normalize", ) @@ -156,6 +160,7 @@ test_suite( ":test_linear", ":test_lstm_cell", ":test_matrix_multiply", + ":test_max", ":test_normalize", ":test_pooling", ":test_reduce", diff --git a/tests/core/conversion/converters/test_max.cpp b/tests/core/conversion/converters/test_max.cpp new file mode 100644 index 0000000000..dfc2432c24 --- /dev/null +++ b/tests/core/conversion/converters/test_max.cpp @@ -0,0 +1,147 @@ +#include +#include "core/compiler.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" + +TEST(Converters, ATenMaxDimConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=0]() + %3 : bool = prim::Constant[value=0]() + %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3) + return (%4, %5))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::rand({2, 3, 5, 5}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6)); +} + +TEST(Converters, ATenMinDimConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=0]() + %3 : bool = prim::Constant[value=0]() + %4 : Tensor, %5 : Tensor = aten::min(%x.1, %2, %3) + return (%4, %5))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::rand({2, 3, 5, 5}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6)); +} + +TEST(Converters, ATenArgMaxConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=0]() + %3 : bool = prim::Constant[value=0]() + %4 : Tensor = aten::argmax(%x.1, %2, %3) + return (%4))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::rand({2, 3, 5, 5}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenArgMaxKeepdimConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=1]() + %3 : bool = prim::Constant[value=1]() + %4 : Tensor = aten::argmax(%x.1, %2, %3) + return (%4))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::rand({2, 3, 5, 5}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenArgMinConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=0]() + %3 : bool = prim::Constant[value=0]() + %4 : Tensor = aten::argmin(%x.1, %2, %3) + return (%4))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::rand({2, 3, 5, 5}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(Converters, ATenArgMinKeepdimConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : int = prim::Constant[value=1]() + %3 : bool = prim::Constant[value=1]() + %4 : Tensor = aten::argmin(%x.1, %2, %3) + return (%4))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::rand({2, 3, 5, 5}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} diff --git a/tests/core/conversion/converters/test_topk.cpp b/tests/core/conversion/converters/test_topk.cpp index 1885493737..c53d209c1f 100644 --- a/tests/core/conversion/converters/test_topk.cpp +++ b/tests/core/conversion/converters/test_topk.cpp @@ -30,28 +30,3 @@ TEST(Converters, ATenTopKConvertsCorrectly) { ASSERT_TRUE( torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6)); } - -TEST(Converters, ATenMaxDimConvertsCorrectly) { - const auto graph = R"IR( - graph(%x.1 : Tensor): - %2 : int = prim::Constant[value=0]() - %3 : bool = prim::Constant[value=0]() - %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3) - return (%4, %5))IR"; - - auto g = std::make_shared(); - torch::jit::parseIR(graph, g.get()); - - auto in = at::rand({2, 3, 5, 5}, {at::kCUDA}); - - auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); - - params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); - auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); - - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); - ASSERT_TRUE( - torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6)); -}