feat (//core/conversion): Add converter for torch.repeat_interleave (#1313)

* added repeat_interleave int repeats converter

* fixed compile time errors

* added repeat_interleave tests, moved converter to expand file

* repeat_interleave passing tests for static input

* implementation and tests for dynamic input repeat_interleave

* dynamic shape checks

* reformatting
blchu authored Aug 28, 2022
1 parent f921c35 commit f350699
Showing 2 changed files with 334 additions and 0 deletions.
110 changes: 110 additions & 0 deletions core/conversion/converters/impl/expand.cpp
@@ -282,6 +282,116 @@ auto expand_registrations TORCHTRT_UNUSED =
               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

               LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
               return true;
             }})
        .pattern(
            {"aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
               auto repeats = args[1].unwrapToScalar().to<int>();

               auto input_shape = self->getDimensions();

               int dim;
               if (args[2].IValue()->isNone()) {
                 dim = 0;

                 // Flatten self tensor
                 int size;
                 if (ctx->input_is_dynamic) {
                   // Total element count is unknown at build time for dynamic input; -1 lets the reshape infer it
                   size = -1;
                 } else {
                   size = 1;
                   for (int i = 0; i < input_shape.nbDims; i++) {
                     size *= input_shape.d[i];
                   }
                 }
                 auto flatten = ctx->net->addShuffle(*self);
                 TORCHTRT_CHECK(flatten, "Unable to create shuffle layer from node: " << *n);
                 flatten->setReshapeDimensions(util::toDims(std::vector<int64_t>({size})));
                 self = flatten->getOutput(0);
                 input_shape = self->getDimensions();
               } else {
                 dim = args[2].unwrapToScalar().to<int>();
               }

               if (ctx->input_is_dynamic) {
                 int dynamic_dims = 0;
                 for (int idx = 0; idx < input_shape.nbDims; idx++) {
                   if (input_shape.d[idx] == -1) {
                     dynamic_dims++;
                   }
                 }

                 // The static reshape dims used below can carry at most one -1 wildcard
                 if (dynamic_dims > 1) {
                   TORCHTRT_THROW_ERROR(
                       "Repeat_interleave is currently not supported when target shape contains more than one dynamic dimension");
                 }
               }

               // Insert singleton dimension after desired repeat dimension
               std::vector<int64_t> repeat_shape_vec;
               for (int j = 0; j < input_shape.nbDims; j++) {
                 repeat_shape_vec.push_back(input_shape.d[j]);
                 if (j == dim) {
                   repeat_shape_vec.push_back(1);
                 }
               }
               auto expand = ctx->net->addShuffle(*self);
               TORCHTRT_CHECK(expand, "Unable to create shuffle layer from node: " << *n);
               auto repeat_shape_dims = util::toDims(repeat_shape_vec);
               expand->setReshapeDimensions(repeat_shape_dims);

               // Expand on newly created singleton dimension
               repeat_shape_dims.d[dim + 1] = repeats;
               std::vector<int64_t> start_vec(repeat_shape_dims.nbDims, 0);
               auto start_dims = util::toDims(start_vec);

               std::vector<int64_t> strides_vec(repeat_shape_dims.nbDims, 1);
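               // A stride of 0 along the singleton dimension re-reads the same element, so the
               // slice below tiles it `repeats` times (the broadcast step of the lowering)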
               strides_vec[dim + 1] = 0;
               auto strides_dims = util::toDims(strides_vec);

               auto slice = ctx->net->addSlice(*expand->getOutput(0), start_dims, repeat_shape_dims, strides_dims);

               if (ctx->input_is_dynamic) {
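                 // Dims are unknown at build time: compute the slice size on-graph as
                 // shape(expand output) * repeats along dim + 1, and feed start/size/stride
                 // to the slice as tensors via setInput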
                 auto start_tensor = tensor_to_const(ctx, torch::tensor(start_vec, torch::kInt32));

                 auto expand_output_shape = ctx->net->addShape(*expand->getOutput(0))->getOutput(0);
                 std::vector<int64_t> repeat_const_vec(repeat_shape_dims.nbDims, 1);
                 repeat_const_vec[dim + 1] = repeats;
                 auto repeat_const = tensor_to_const(ctx, torch::tensor(repeat_const_vec, torch::kInt32));
                 auto repeat_shape_tensor =
                     ctx->net
                         ->addElementWise(*expand_output_shape, *repeat_const, nvinfer1::ElementWiseOperation::kPROD)
                         ->getOutput(0);

                 auto strides_tensor = tensor_to_const(ctx, torch::tensor(strides_vec, torch::kInt32));
                 slice->setInput(1, *start_tensor);
                 slice->setInput(2, *repeat_shape_tensor);
                 slice->setInput(3, *strides_tensor);
               }

               // Collapse repeated dimension back into desired dimension
               std::vector<int64_t> collapse_shape_vec;
               for (int k = 0; k < repeat_shape_dims.nbDims; k++) {
                 if (k == dim) {
                   int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[k + 1];
                   // A dynamic dim (-1) multiplied by repeats would be less than -1, so clamp back to -1
                   collapse_dim = std::max(collapse_dim, (int64_t)-1);
                   collapse_shape_vec.push_back(collapse_dim);
                   k++; // skip the singleton dimension that was just folded in
                 } else {
                   collapse_shape_vec.push_back(repeat_shape_dims.d[k]);
                 }
               }
               auto collapse = ctx->net->addShuffle(*slice->getOutput(0));
               TORCHTRT_CHECK(collapse, "Unable to create shuffle layer from node: " << *n);
               collapse->setReshapeDimensions(util::toDims(collapse_shape_vec));

               collapse->setName(util::node_info(n).c_str());
               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], collapse->getOutput(0));
               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

               return true;
             }});
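
For reference, the lowering above is the classic unsqueeze → broadcast → reshape trick. A minimal LibTorch sketch can check it against at::repeat_interleave directly; this is illustrative only (the {2, 3} input, repeats=3, and the main() harness are assumptions, not part of this diff), while torch::repeat_interleave, Tensor::unsqueeze, expand, and reshape are standard LibTorch APIs:

// Sketch: verify the unsqueeze -> expand -> reshape lowering used by the
// converter, for an assumed static {2, 3} input with repeats=3 on dim=1.
#include <cassert>

#include <torch/torch.h>

int main() {
  auto x = torch::arange(6).reshape({2, 3});

  // Reference semantics the converter must reproduce
  auto expected = torch::repeat_interleave(x, /*repeats=*/3, /*dim=*/1); // {2, 9}

  // The converter's lowering: insert a singleton after dim, broadcast it
  // (the stride-0 slice above), then fold it back into dim
  auto manual = x.unsqueeze(2)         // {2, 3, 1}
                    .expand({2, 3, 3}) // broadcast the singleton to repeats
                    .reshape({2, 9});  // collapse back into dim 1

  assert(torch::equal(expected, manual));
  return 0;
}

TensorRT has no expand layer, so the converter expresses the same three steps as shuffle → stride-0 slice → shuffle.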

224 changes: 224 additions & 0 deletions tests/core/conversion/converters/test_expand.cpp
@@ -445,3 +445,227 @@ TEST(Converters, ATenRepeatExtraDimsConvertsCorrectlyWithDynamicInput) {

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleaveScalarDimConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : int = prim::Constant[value=1]()
        %4 : None = prim::Constant()
        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
        return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleaveScalarDimConvertsCorrectlyWithDynamicInput) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : int = prim::Constant[value=1]()
        %4 : None = prim::Constant()
        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
        return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleaveScalarNoDimConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : None = prim::Constant()
        %4 : None = prim::Constant()
        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
        return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleaveScalarNoDimConvertsCorrectlyWithDynamicInput) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : None = prim::Constant()
        %4 : None = prim::Constant()
        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
        return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleave3dScalarDimConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : int = prim::Constant[value=1]()
        %4 : None = prim::Constant()
        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
        return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleave3dScalarDimConvertsCorrectlyWithDynamicInput) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : int = prim::Constant[value=1]()
        %4 : None = prim::Constant()
        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
        return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectly) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : None = prim::Constant()
        %4 : None = prim::Constant()
        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
        return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicInput) {
  const auto graph = R"IR(
      graph(%x.1 : Tensor):
        %2 : int = prim::Constant[value=3]()
        %3 : None = prim::Constant()
        %4 : None = prim::Constant()
        %5 : Tensor = aten::repeat_interleave(%x.1, %2, %3, %4)
        return (%5))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
