Skip to content

Commit

Permalink
feat: support type promotion in aten::cat converter (#1911)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfeliz-cruise authored and narendasan committed May 30, 2023
1 parent 453266b commit e5a43df
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
2 changes: 2 additions & 0 deletions core/conversion/converters/converter_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ nvinfer1::ITensor* get_slice_size(

nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s);

nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b);

} // namespace converters
} // namespace conversion
} // namespace core
Expand Down
12 changes: 12 additions & 0 deletions core/conversion/converters/impl/concat.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "core/conversion/converters/converter_util.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"
Expand Down Expand Up @@ -27,6 +28,17 @@ auto cat_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
}
}

auto promo_dtype = tensors[0]->getType();
for(size_t idx = 1UL; idx < tensors.size(); ++idx){
promo_dtype = promote_types(promo_dtype, tensors[idx]->getType());
}

for(size_t idx = 0UL; idx < tensors.size(); ++idx){
if(tensors[idx]->getType() != promo_dtype){
tensors[idx] = castITensor(ctx, tensors[idx], promo_dtype, util::node_info(n) + "_cast_" + std::to_string(idx));
}
}

if (dim < 0) {
dim = tensors[0]->getDimensions().nbDims + dim;
}
Expand Down
79 changes: 79 additions & 0 deletions tests/core/conversion/converters/test_concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,85 @@ TEST(Converters, ATenCatPureTensorConvertsCorrectly) {
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenCatFloatIntConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor):
%2 : Tensor[] = prim::ListConstruct(%0, %1)
%3 : int = prim::Constant[value=0]()
%4 : Tensor = aten::cat(%2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kFloat);
auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenCatIntHalfIntHalfConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor,
%2 : Tensor,
%3 : Tensor):
%2 : Tensor[] = prim::ListConstruct(%0, %1, %2, %3)
%3 : int = prim::Constant[value=0]()
%4 : Tensor = aten::cat(%2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf);
auto in3 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
auto in4 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf);

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2, in3, in4});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results =
torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2, in3, in4}, nvinfer1::DataType::kHALF);

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenCatHalfIntFloatConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
%1 : Tensor,
%2 : Tensor):
%2 : Tensor[] = prim::ListConstruct(%0, %1, %2)
%3 : int = prim::Constant[value=0]()
%4 : Tensor = aten::cat(%2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf);
auto in3 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kFloat);

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2, in3});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2, in3});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor,
Expand Down

0 comments on commit e5a43df

Please sign in to comment.