Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for aten::meshgrid #1601

Merged
merged 1 commit into from
Jan 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions core/conversion/converters/impl/expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,89 @@ auto expand_registrations TORCHTRT_UNUSED =
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], collapse->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

return true;
}})
.pattern(
{"aten::meshgrid(Tensor[] tensors) -> (Tensor[])",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// torch.meshgrid only supports 1D or 0D input tensors
auto arg_tensors = args[0].IValue()->toListRef();
std::vector<nvinfer1::ITensor*> tensors;
for (auto t : arg_tensors) {
if (t.isTensor()) {
auto torch_tensor = t.toTensor();
tensors.push_back(tensor_to_const(ctx, torch_tensor));
} else {
auto cont = t.toCustomClass<TensorContainer>();
tensors.push_back(cont->tensor());
}
}

// build the output shape for all tensors in the output list
nvinfer1::Dims output_dims;
output_dims.nbDims = tensors.size();
for (size_t idx = 0UL; idx < tensors.size(); ++idx) {
auto dims = tensors[idx]->getDimensions();
output_dims.d[idx] = dims.nbDims == 0 ? 1 : dims.d[0];
}
std::vector<nvinfer1::ITensor*> out_tensors;
// Reshape tensors into output shape (reshape, expand)
for (size_t idx = 0UL; idx < tensors.size(); ++idx) {
auto t = tensors[idx];
auto dims = t->getDimensions();
nvinfer1::Dims reshape_dims;
reshape_dims.nbDims = tensors.size();
for (size_t reshape_idx = 0UL; reshape_idx < tensors.size(); ++reshape_idx) {
if (reshape_idx == idx) {
reshape_dims.d[reshape_idx] = dims.nbDims == 0 ? 1 : dims.d[0];
} else {
reshape_dims.d[reshape_idx] = 1;
}
}
// Add a reshape layer before expanding dims
auto reshape_layer = ctx->net->addShuffle(*t);
reshape_layer->setReshapeDimensions(reshape_dims);
std::stringstream reshape_layer_name;
reshape_layer_name << util::node_info(n) << "_meshgrid_reshape_" << std::to_string(idx);
reshape_layer->setName(reshape_layer_name.str().c_str());
auto reshaped = reshape_layer->getOutput(0);
LOG_DEBUG("Tensor " << idx << " reshaped to : " << reshaped->getDimensions() << " from " << dims);

// Add slice layer for expansion
std::vector<int64_t> start_vec(output_dims.nbDims, 0);
auto start_offset = util::toDims(c10::IntArrayRef(start_vec));

std::vector<int64_t> strides_vec(output_dims.nbDims, 0);
for (int64_t i = 0; i < output_dims.nbDims; i++) {
strides_vec[i] = (reshaped->getDimensions().d[i] != 1);
}

auto strides = util::toDims(c10::IntArrayRef(strides_vec));

auto slice_layer = ctx->net->addSlice(*reshaped, start_offset, output_dims, strides);
std::stringstream slice_layer_name;
slice_layer_name << util::node_info(n) << "_meshgrid_slice_" << std::to_string(idx);
slice_layer->setName(slice_layer_name.str().c_str());
auto slice_output = slice_layer->getOutput(0);
LOG_DEBUG("Tensor " << idx << " expanded to : " << slice_output->getDimensions());
out_tensors.push_back(slice_output);
}

// Pack output tensors into list
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
c10::TypePtr elementType = lt->getElementType();
auto list = c10::impl::GenericList(elementType);
list.reserve(out_tensors.size());

for (auto t : out_tensors) {
auto tensor_holder = TensorContainer();
tensor_holder.hold_tensor(t);
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
list.emplace_back(ival);
}

auto output_list = std::move(torch::jit::IValue(list));
ctx->AssociateValueAndIValue(n->outputs()[0], output_list);
return true;
}});

Expand Down
35 changes: 35 additions & 0 deletions tests/core/conversion/converters/test_expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,38 @@ TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicIn

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenMeshGridConvertsCorrectly) {
const auto graph = R"IR(
graph(%x : Tensor, %y : Tensor, %z : Tensor):
%0 : Tensor[] = prim::ListConstruct(%x, %y, %z)
%1 : Tensor[] = aten::meshgrid(%0)
%x_0 : Tensor, %y_0 : Tensor, %z_0 : Tensor = prim::ListUnpack(%1)
return (%x_0, %y_0, %z_0))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

auto x = at::randint(1, 10, {2}, {at::kCUDA}).to(torch::kInt);
auto jit_x = at::clone(x);

auto y = at::randint(1, 10, {5}, {at::kCUDA}).to(torch::kInt);
auto jit_y = at::clone(y);

auto z = torch::tensor(22, {at::kCUDA}).to(torch::kInt); // 0D
auto jit_z = at::clone(z);

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_x, jit_y, jit_z});

auto trt_x = at::clone(jit_x);
auto trt_y = at::clone(jit_y);
auto trt_z = at::clone(jit_z);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_x, trt_y, trt_z});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[1], trt_results[1], 2e-6));
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[2], trt_results[2], 2e-6));
}