refactor(aten::embedding): Refactor with linting
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Nov 10, 2020
1 parent 66154c7 commit b39bcbc
Showing 2 changed files with 39 additions and 39 deletions.
38 changes: 19 additions & 19 deletions core/conversion/converters/impl/select.cpp
@@ -120,25 +120,25 @@ auto select_registrations TRTORCH_UNUSED =

               return true;
             }})
-    .pattern({
-        "aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)",
-        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-          auto embeddingTensor = args[0].ITensorOrFreeze(ctx);
-          auto indicesTensor = args[1].ITensor();
-          // Set datatype for indices tensor to INT32
-          indicesTensor->setType(nvinfer1::DataType::kINT32);
-          // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
-          auto gather_layer = ctx->net->addGather(*embeddingTensor, *indicesTensor, 0);
-          TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
-          auto gather_out = gather_layer->getOutput(0);
-
-          auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
-
-          LOG_DEBUG("Output tensor shape: " << out->getDimensions());
-
-          return true;
-        }});
+        .pattern(
+            {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto embeddingTensor = args[0].ITensorOrFreeze(ctx);
+               auto indicesTensor = args[1].ITensor();
+               // Set datatype for indices tensor to INT32
+               indicesTensor->setType(nvinfer1::DataType::kINT32);
+
+               // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
+               auto gather_layer = ctx->net->addGather(*embeddingTensor, *indicesTensor, 0);
+               TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
+               auto gather_out = gather_layer->getOutput(0);
+
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
+
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+
+               return true;
+             }});

} // namespace
} // namespace impl
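For context on the converter above: with its default arguments, aten::embedding is just a row lookup, so the converter lowers it to a single TensorRT gather along axis 0 of the weight tensor. The sketch below shows the same construction against the raw TensorRT network-building API, outside the TRTorch converter framework. The function name is illustrative, and `network`, `weights`, and `indices` are assumed to already exist; in the converter they come from the ConversionCtx and the node's arguments.

#include "NvInfer.h"

// Illustrative sketch only: an embedding lookup is a gather along axis 0 of the
// weight matrix. The caller supplies the network and the two input tensors.
nvinfer1::ITensor* add_embedding_as_gather(
    nvinfer1::INetworkDefinition* network,
    nvinfer1::ITensor* weights,   // shape: [num_embeddings, embedding_dim]
    nvinfer1::ITensor* indices) { // integer token indices
  // TensorRT gathers expect INT32 indices, hence the setType call in the converter.
  indices->setType(nvinfer1::DataType::kINT32);

  // Axis 0: each index selects one row (one embedding vector) of the weight matrix.
  nvinfer1::IGatherLayer* gather = network->addGather(*weights, *indices, 0);
  if (!gather) {
    return nullptr;
  }
  return gather->getOutput(0);
}

This mirrors the three steps in the converter: fetch (or freeze) the weight tensor, force the index dtype to INT32, and add the gather layer whose output becomes the node's output tensor.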
40 changes: 20 additions & 20 deletions tests/core/converters/test_select.cpp
@@ -86,29 +86,29 @@ TEST(Converters, ATenNarrowStartScalarConvertsCorrectly) {
}

TEST(Converters, ATenEmbeddingConvertsCorrectly) {
-    const auto graph = R"IR(
+  const auto graph = R"IR(
      graph(%1 : Tensor, %emb_weight : Float(10:3, 3:1)):
        %2 : bool = prim::Constant[value=0]()
        %3 : int = prim::Constant[value=-1]()
        %5 : Tensor = aten::embedding(%emb_weight, %1, %3, %2, %2)
        return (%5))IR";
-    auto g = std::make_shared<torch::jit::Graph>();
-    // Run Pytorch
-    torch::jit::parseIR(graph, &*g);
-    auto options_pyt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kLong);
-    auto jit_in = at::tensor({0, 1, 2}, options_pyt);
-    auto embWeight = at::randn({10, 3}, {at::kCUDA});
-    auto params = trtorch::core::conversion::get_named_params(g->inputs(), {embWeight});
-    auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
-    // Run TensorRT
-    auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kI32);
-    auto trt_in = at::tensor({0, 1, 2}, options_trt);
-    auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
-    auto trt = trt_results[0].reshape(jit_results[0].sizes());
-
-    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
-
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  // Run Pytorch
+  torch::jit::parseIR(graph, &*g);
+  auto options_pyt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kLong);
+  auto jit_in = at::tensor({0, 1, 2}, options_pyt);
+  auto embWeight = at::randn({10, 3}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {embWeight});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  // Run TensorRT
+  auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kI32);
+  auto trt_in = at::tensor({0, 1, 2}, options_trt);
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
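As a sanity check on the semantics this test relies on: with padding_idx=-1, scale_grad_by_freq=False, and sparse=False, aten::embedding is numerically a row selection from the weight matrix, which is exactly what the TensorRT gather on axis 0 computes. Below is a minimal LibTorch sketch of that equivalence; it is a standalone CPU-only illustration, not part of the test suite.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Same shapes as the test above: a 10x3 embedding table and three indices.
  auto weight = torch::randn({10, 3});
  auto indices = torch::tensor({0, 1, 2}, torch::kLong);

  // What the TorchScript graph computes (default embedding arguments).
  auto emb = torch::embedding(weight, indices, /*padding_idx=*/-1,
                              /*scale_grad_by_freq=*/false, /*sparse=*/false);

  // What a gather along axis 0 of the weight matrix computes.
  auto gathered = weight.index_select(0, indices);

  std::cout << std::boolalpha << torch::allclose(emb, gathered) << std::endl; // prints: true
  return 0;
}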
