[feat] Add support for argmax and argmin (#1312)
* [feat] Add support for argmax and argmin

Adds converter support for aten::argmax and aten::argmin (a sketch of the lowering strategy follows the checklist below).

Type of change:

- New feature (non-breaking change which adds functionality)

Checklist:

- [ ] My code follows the style guidelines of this project (You can use the linters)
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [ ] I have added tests to verify my fix or my feature
- [ ] New and existing unit tests pass locally with my changes
- [ ] I have added the relevant labels to my PR so that relevant reviewers are notified

* Move max.cpp tests into test_max.cpp (no functional change)

* Fix file permissions on max.cpp
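
For reviewers unfamiliar with the lowering: both ops are built on TensorRT's ITopKLayer with k = 1, which emits a (values, indices) pair; argmax/argmin keep only the indices, and an IShuffleLayer squeezes away the reduced axis when keepdim is false. A minimal sketch of the idea, simplified from the converter in this diff (`network`, `input`, and `dim` are placeholder names, not identifiers from the change):

    // Sketch only: reduce a single axis with TopK(k=1) and read the index output.
    // addTopK takes the reduction axes as a bitmask, so set the bit for `dim`.
    uint32_t reduce_axes_mask = 1u << dim;
    auto* topk = network->addTopK(*input, nvinfer1::TopKOperation::kMAX,
                                  /*k=*/1, reduce_axes_mask);
    nvinfer1::ITensor* indices = topk->getOutput(1); // argmax result (keepdim=true shape)
    // With keepdim=false, an IShuffleLayer then reshapes away the size-1 axis.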
mfeliz-cruise authored Sep 2, 2022
1 parent 2af5942 commit 9db2852
Showing 4 changed files with 241 additions and 66 deletions.
130 changes: 89 additions & 41 deletions core/conversion/converters/impl/max.cpp
@@ -13,47 +13,95 @@ namespace conversion {
namespace converters {
namespace impl {
namespace {
auto max_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto self = args[0].ITensorOrFreeze(ctx);
auto dim = args[1].unwrapToInt();
auto keep_dims = args[2].unwrapToBool();
auto selfDim = util::toVec(self->getDimensions());
if (dim < 0) {
dim = selfDim.size() + dim;
}
uint32_t shiftDim = 1 << dim;
auto TopKOperation = nvinfer1::TopKOperation::kMAX;
auto topk_layer = ctx->net->addTopK(*self, TopKOperation, 1, shiftDim);
TORCHTRT_CHECK(topk_layer, "Unable to create max layer from node: " << *n);
auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions());

nvinfer1::ITensor* out0 = nullptr;
nvinfer1::ITensor* out1 = nullptr;
if (!keep_dims) {
if (topk_dims[dim] == 1) {
auto squeeze_layer = ctx->net->addShuffle(*topk_layer->getOutput(0));
squeeze_layer->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(0)->getDimensions(), dim));
TORCHTRT_CHECK(squeeze_layer, "Unable to create squeeze_layer layer from node: " << *n);
out0 = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer->getOutput(0));

auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1));
squeeze_layer_indices->setReshapeDimensions(
util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim));
TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n);
out1 = ctx->AssociateValueAndTensor(n->outputs()[1], squeeze_layer_indices->getOutput(0));
}
} else {
out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0));
out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1));
}

LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());

return true;
}});

bool min_max_dim(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvinfer1::TopKOperation topKOperation) {
auto self = args[0].ITensorOrFreeze(ctx);
auto dim = args[1].unwrapToInt();
auto keep_dims = args[2].unwrapToBool();
auto selfDim = util::toVec(self->getDimensions());
if (dim < 0) {
dim = selfDim.size() + dim;
}
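// addTopK takes the reduction axes as a bitmask; set only the bit for `dim`.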
uint32_t reduce_axes_mask = 1 << dim;
auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask);
TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n);
auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions());

nvinfer1::ITensor* out0 = nullptr;
nvinfer1::ITensor* out1 = nullptr;
if (!keep_dims) {
TORCHTRT_CHECK(topk_dims[dim] == 1, "Unexpected size in squeeze dimension. Expected: 1 Actual: " << topk_dims[dim]);
auto squeeze_layer = ctx->net->addShuffle(*topk_layer->getOutput(0));
squeeze_layer->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(0)->getDimensions(), dim));
TORCHTRT_CHECK(squeeze_layer, "Unable to create squeeze_layer layer from node: " << *n);
out0 = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer->getOutput(0));

auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1));
squeeze_layer_indices->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim));
TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n);
out1 = ctx->AssociateValueAndTensor(n->outputs()[1], squeeze_layer_indices->getOutput(0));
} else {
out0 = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(0));
out1 = ctx->AssociateValueAndTensor(n->outputs()[1], topk_layer->getOutput(1));
}

LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());

return true;
}

bool arg_min_max(ConversionCtx* ctx, const torch::jit::Node* n, args& args, nvinfer1::TopKOperation topKOperation) {
auto self = args[0].ITensorOrFreeze(ctx);
auto dim = args[1].unwrapToInt();
auto keep_dims = args[2].unwrapToBool();
auto selfDim = util::toVec(self->getDimensions());
if (dim < 0) {
dim = selfDim.size() + dim;
}
uint32_t reduce_axes_mask = 1 << dim;
auto topk_layer = ctx->net->addTopK(*self, topKOperation, 1, reduce_axes_mask);
TORCHTRT_CHECK(topk_layer, "Unable to create topk layer from node: " << *n);
auto topk_dims = util::toVec(topk_layer->getOutput(0)->getDimensions());

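// aten::argmax/argmin return only an index tensor, so only output 1
// (the indices) of the TopK layer is consumed here.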
nvinfer1::ITensor* out = nullptr;
if (!keep_dims) {
TORCHTRT_CHECK(topk_dims[dim] == 1, "Unexpected size in squeeze dimension. Expected: 1 Actual: " << topk_dims[dim]);
auto squeeze_layer_indices = ctx->net->addShuffle(*topk_layer->getOutput(1));
squeeze_layer_indices->setReshapeDimensions(util::squeezeDims(topk_layer->getOutput(1)->getDimensions(), dim));
TORCHTRT_CHECK(squeeze_layer_indices, "Unable to create squeeze_layer_indices layer from node: " << *n);
out = ctx->AssociateValueAndTensor(n->outputs()[0], squeeze_layer_indices->getOutput(0));
} else {
out = ctx->AssociateValueAndTensor(n->outputs()[0], topk_layer->getOutput(1));
}

LOG_DEBUG("Output tensor shape: " << out->getDimensions());

return true;
}

auto max_registrations TORCHTRT_UNUSED =
RegisterNodeConversionPatterns()
.pattern(
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
return min_max_dim(ctx, n, args, nvinfer1::TopKOperation::kMAX);
}})
.pattern(
{"aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
return min_max_dim(ctx, n, args, nvinfer1::TopKOperation::kMIN);
}})
.pattern(
{"aten::argmax(Tensor self, int dim, bool keepdim=False) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
return arg_min_max(ctx, n, args, nvinfer1::TopKOperation::kMAX);
}})
.pattern(
{"aten::argmin(Tensor self, int dim, bool keepdim=False) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
return arg_min_max(ctx, n, args, nvinfer1::TopKOperation::kMIN);
}});
} // namespace
} // namespace impl
} // namespace converters
5 changes: 5 additions & 0 deletions tests/core/conversion/converters/BUILD
@@ -71,6 +71,10 @@ converter_test(
name = "test_matrix_multiply",
)

converter_test(
name = "test_max",
)

converter_test(
name = "test_normalize",
)
@@ -156,6 +160,7 @@ test_suite(
":test_linear",
":test_lstm_cell",
":test_matrix_multiply",
":test_max",
":test_normalize",
":test_pooling",
":test_reduce",
147 changes: 147 additions & 0 deletions tests/core/conversion/converters/test_max.cpp
@@ -0,0 +1,147 @@
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenMaxDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=0]()
%3 : bool = prim::Constant[value=0]()
%4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
return (%4, %5))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
}

TEST(Converters, ATenMinDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=0]()
%3 : bool = prim::Constant[value=0]()
%4 : Tensor, %5 : Tensor = aten::min(%x.1, %2, %3)
return (%4, %5))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
}

TEST(Converters, ATenArgMaxConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=0]()
%3 : bool = prim::Constant[value=0]()
%4 : Tensor = aten::argmax(%x.1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenArgMaxKeepdimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%3 : bool = prim::Constant[value=1]()
%4 : Tensor = aten::argmax(%x.1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenArgMinConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=0]()
%3 : bool = prim::Constant[value=0]()
%4 : Tensor = aten::argmin(%x.1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenArgMinKeepdimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%3 : bool = prim::Constant[value=1]()
%4 : Tensor = aten::argmin(%x.1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
25 changes: 0 additions & 25 deletions tests/core/conversion/converters/test_topk.cpp
@@ -30,28 +30,3 @@ TEST(Converters, ATenTopKConvertsCorrectly) {
ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
}

TEST(Converters, ATenMaxDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=0]()
%3 : bool = prim::Constant[value=0]()
%4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
return (%4, %5))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::rand({2, 3, 5, 5}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
}
