diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index bb829d06ae..8b8611e3a4 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -120,25 +120,25 @@ auto select_registrations TRTORCH_UNUSED = return true; }}) - .pattern({ - "aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto embeddingTensor = args[0].ITensorOrFreeze(ctx); - auto indicesTensor = args[1].ITensor(); - // Set datatype for indices tensor to INT32 - indicesTensor->setType(nvinfer1::DataType::kINT32); - - // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from - auto gather_layer = ctx->net->addGather(*embeddingTensor, *indicesTensor, 0); - TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); - auto gather_out = gather_layer->getOutput(0); - - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out); - - LOG_DEBUG("Output tensor shape: " << out->getDimensions()); - - return true; - }}); + .pattern( + {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto embeddingTensor = args[0].ITensorOrFreeze(ctx); + auto indicesTensor = args[1].ITensor(); + // Set datatype for indices tensor to INT32 + indicesTensor->setType(nvinfer1::DataType::kINT32); + + // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from + auto gather_layer = ctx->net->addGather(*embeddingTensor, *indicesTensor, 0); + TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); + auto gather_out = gather_layer->getOutput(0); + + auto out = 
ctx->AssociateValueAndTensor(n->outputs()[0], gather_out); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + + return true; + }}); } // namespace } // namespace impl diff --git a/tests/core/converters/test_select.cpp b/tests/core/converters/test_select.cpp index a85ad5d754..4ab88b4ef7 100644 --- a/tests/core/converters/test_select.cpp +++ b/tests/core/converters/test_select.cpp @@ -86,29 +86,29 @@ TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) { } TEST(Converters, ATenEmbeddingConvertsCorrectly) { - const auto graph = R"IR( + const auto graph = R"IR( graph(%1 : Tensor, %emb_weight : Float(10:3, 3:1)): %2 : bool = prim::Constant[value=0]() %3 : int = prim::Constant[value=-1]() %5 : Tensor = aten::embedding(%emb_weight, %1, %3, %2, %2) return (%5))IR"; - - auto g = std::make_shared<torch::jit::Graph>(); - - // Run Pytorch - torch::jit::parseIR(graph, &*g); - auto options_pyt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kLong); - auto jit_in = at::tensor({0, 1, 2}, options_pyt); - auto embWeight = at::randn({10, 3}, {at::kCUDA}); - - auto params = trtorch::core::conversion::get_named_params(g->inputs(), {embWeight}); - auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); - - // Run TensorRT - auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kI32); - auto trt_in = at::tensor({0, 1, 2}, options_trt); - auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); - auto trt = trt_results[0].reshape(jit_results[0].sizes()); - - ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); + + auto g = std::make_shared<torch::jit::Graph>(); + + // Run Pytorch + torch::jit::parseIR(graph, &*g); + auto options_pyt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kLong); + auto jit_in = at::tensor({0, 1, 2}, options_pyt); + auto embWeight = at::randn({10, 3}, {at::kCUDA}); + + auto params = trtorch::core::conversion::get_named_params(g->inputs(), {embWeight}); + auto jit_results = 
trtorch::tests::util::RunGraph(g, params, {jit_in}); + + // Run TensorRT + auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kI32); + auto trt_in = at::tensor({0, 1, 2}, options_trt); + auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } \ No newline at end of file