From 01e6541e7a54c3ce76f4ebefb6200d82787768e4 Mon Sep 17 00:00:00 2001
From: inocsin
Date: Thu, 11 Aug 2022 11:31:13 +0800
Subject: [PATCH] feat: support scatter.value and scatter.src

Signed-off-by: inocsin
---
 core/conversion/converters/converter_util.cpp | 21 ++++++
 core/conversion/converters/converter_util.h   |  2 +
 .../converters/impl/element_wise.cpp          | 21 +-----
 core/conversion/converters/impl/select.cpp    | 47 ++++++++++++
 core/lowering/passes/op_aliasing.cpp          | 18 ++++-
 .../conversion/converters/test_select.cpp     | 73 +++++++++++++++++++
 6 files changed, 158 insertions(+), 24 deletions(-)

diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp
index 94ac827ef4..5690b8f3e4 100644
--- a/core/conversion/converters/converter_util.cpp
+++ b/core/conversion/converters/converter_util.cpp
@@ -335,6 +335,27 @@ nvinfer1::ITensor* get_slice_size(
   return size_itensor;
 }
 
+nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
+  nvinfer1::ITensor* out;
+  if (s.isIntegral(false)) {
+    auto s_int = s.to<int64_t>();
+    auto s_t = torch::tensor({s_int}).to(at::kInt);
+    out = tensor_to_const(ctx, s_t);
+  } else if (s.isBoolean()) {
+    auto s_bool = s.to<bool>();
+    auto s_t = torch::tensor({s_bool}).to(at::kBool);
+    out = tensor_to_const(ctx, s_t);
+  } else if (s.isFloatingPoint()) {
+    auto other_float = s.to<float>();
+    auto s_t = torch::tensor({other_float});
+    out = tensor_to_const(ctx, s_t);
+  } else {
+    out = nullptr;
+    TORCHTRT_THROW_ERROR("Unsupported data type for scalar. Found: (" << s.type() << ")");
+  }
+  return out;
+}
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
diff --git a/core/conversion/converters/converter_util.h b/core/conversion/converters/converter_util.h
index b155499858..2f5d4b25a9 100644
--- a/core/conversion/converters/converter_util.h
+++ b/core/conversion/converters/converter_util.h
@@ -80,6 +80,8 @@ nvinfer1::ITensor* get_slice_size(
     int nbdims,
     std::string const& name);
 
+nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s);
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
index f2770508ca..20c1112f0a 100644
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -25,26 +25,7 @@ nvinfer1::ITensor* clamp_util(
   return clamp_layer_out;
 }
 
-nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
-  nvinfer1::ITensor* out;
-  if (s.isIntegral(false)) {
-    auto s_int = s.to<int64_t>();
-    auto s_t = torch::tensor({s_int}).to(at::kInt);
-    out = tensor_to_const(ctx, s_t);
-  } else if (s.isBoolean()) {
-    auto s_bool = s.to<bool>();
-    auto s_t = torch::tensor({s_bool}).to(at::kBool);
-    out = tensor_to_const(ctx, s_t);
-  } else if (s.isFloatingPoint()) {
-    auto other_float = s.to<float>();
-    auto s_t = torch::tensor({other_float});
-    out = tensor_to_const(ctx, s_t);
-  } else {
-    out = nullptr;
-    TORCHTRT_THROW_ERROR("Unsupported data type for scalar. Found: (" << s.type() << ")");
-  }
-  return out;
-}
+
 auto element_wise_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
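Note: the relocated scalar_to_tensor helper dispatches on the at::Scalar's runtime
type before building a constant tensor. Below is a minimal standalone sketch of that
dispatch using only public libtorch API; scalar_to_cpu_tensor is a hypothetical name,
and the real helper above additionally wraps the result into an nvinfer1::ITensor*
via tensor_to_const, which needs a ConversionCtx and is omitted here.

    #include <torch/torch.h>
    #include <iostream>
    #include <stdexcept>

    // Hypothetical standalone mirror of the type dispatch in scalar_to_tensor
    // (not part of this patch; TensorRT plumbing removed).
    at::Tensor scalar_to_cpu_tensor(at::Scalar s) {
      if (s.isIntegral(false)) {
        // Integral scalars become 1-element Int32 tensors.
        return torch::tensor({s.to<int64_t>()}).to(at::kInt);
      } else if (s.isBoolean()) {
        return torch::tensor({s.to<bool>()}).to(at::kBool);
      } else if (s.isFloatingPoint()) {
        // Floating-point scalars default to Float32.
        return torch::tensor({s.to<float>()});
      }
      throw std::runtime_error("Unsupported data type for scalar");
    }

    int main() {
      std::cout << scalar_to_cpu_tensor(at::Scalar(3)).dtype() << "\n";    // Int
      std::cout << scalar_to_cpu_tensor(at::Scalar(true)).dtype() << "\n"; // Bool
      std::cout << scalar_to_cpu_tensor(at::Scalar(1.5)).dtype() << "\n";  // Float
    }

Note that isIntegral(false) is checked before isBoolean(): with includeBool=false,
a bool scalar does not count as integral, so the Bool branch stays reachable.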
Found: (" << s.type() << ")"); - } - return out; -} + auto element_wise_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 31814a682b..20a03f6f5e 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -464,6 +464,53 @@ auto select_registrations TORCHTRT_UNUSED = auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0)); LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); return true; + }}) + .pattern( + {"aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto self = args[0].ITensorOrFreeze(ctx); + int dim = args[1].unwrapToInt(); + auto index = args[2].ITensorOrFreeze(ctx); + auto index_dim = index->getDimensions(); + std::vector dim_vec; + for (int i = 0; i < index_dim.nbDims; i++) { + dim_vec.push_back(index_dim.d[i]); + } + auto value = args[3].unwrapToScalar() * torch::ones(dim_vec); + auto value_tensor = tensor_to_const(ctx, value, ""); + if (self->getType() != value_tensor->getType()) { + value_tensor = castITensor(ctx, value_tensor, self->getType()); + } + + auto layer = ctx->net->addScatter(*self, *index, *value_tensor, nvinfer1::ScatterMode::kELEMENT); + layer->setAxis(dim); + + TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.value"); + + layer->setName(util::node_info(n).c_str()); + + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0)); + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + }}) + .pattern( + {"aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto self = args[0].ITensorOrFreeze(ctx); + int dim = args[1].unwrapToInt(); + auto index = args[2].ITensorOrFreeze(ctx); + auto src = args[3].ITensorOrFreeze(ctx); + + auto layer = ctx->net->addScatter(*self, *index, *src, nvinfer1::ScatterMode::kELEMENT); + layer->setAxis(dim); + + TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.src"); + + layer->setName(util::node_info(n).c_str()); + + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0)); + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; }}); } // namespace diff --git a/core/lowering/passes/op_aliasing.cpp b/core/lowering/passes/op_aliasing.cpp index 79ebaf6a02..6be85be534 100644 --- a/core/lowering/passes/op_aliasing.cpp +++ b/core/lowering/passes/op_aliasing.cpp @@ -16,15 +16,25 @@ void AliasOperators(std::shared_ptr& graph) { graph(%s, %o): %1 : Tensor = aten::div(%s, %o) return (%1))IR"; - ; - - // TODO - // complete other element wise pass torch::jit::SubgraphRewriter true_divide_to_div; true_divide_to_div.RegisterRewritePattern(true_divide_pattern, div_pattern); true_divide_to_div.runOnGraph(graph); LOG_GRAPH("Post map true_divide -> div: " << *graph); + + std::string scatter_sub_pattern = R"IR( + graph(%data, %dim, %index, %value): + %o : Tensor = aten::scatter_(%data, %dim, %index, %value) + return (%o))IR"; + std::string scatter_pattern = R"IR( + graph(%data, %dim, %index, %value): + %o : Tensor = aten::scatter(%data, %dim, %index, %value) + return (%o))IR"; + + torch::jit::SubgraphRewriter rewrite_scatter; + rewrite_scatter.RegisterRewritePattern(scatter_sub_pattern, scatter_pattern); + 
diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp
index 67b760aa24..d77fa37d40 100644
--- a/tests/core/conversion/converters/test_select.cpp
+++ b/tests/core/conversion/converters/test_select.cpp
@@ -855,3 +855,76 @@ TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
     ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
   }
 }
+
+TEST(Converters, ScatterValueConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%data : Tensor,
+            %index.1 : Tensor):
+        %value : int = prim::Constant[value=100]()
+        %dim : int = prim::Constant[value=1]()
+        %5 : NoneType = prim::Constant()
+        %6 : bool = prim::Constant[value=0]()
+        %7 : int = prim::Constant[value=4]()
+        %index : Tensor = aten::to(%index.1, %7, %6, %6, %5)
+        %10 : Tensor = aten::scatter(%data, %dim, %index, %value)
+        return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto index = at::randint(0, 5, {2, 2}, {at::kCUDA});
+  auto data = at::randn({5, 5}, {at::kCUDA});
+
+  auto jit_index = at::clone(index);
+  auto jit_data = at::clone(data);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_index});
+
+  auto trt_index = at::clone(index);
+  auto trt_data = at::clone(data);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_index});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    auto trt = trt_results[i].reshape(jit_results[i].sizes());
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+  }
+}
+
+TEST(Converters, ScatterSrcConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%data : Tensor,
+            %src : Tensor,
+            %index.1 : Tensor):
+        %dim : int = prim::Constant[value=1]()
+        %5 : NoneType = prim::Constant()
+        %6 : bool = prim::Constant[value=0]()
+        %7 : int = prim::Constant[value=4]()
+        %index : Tensor = aten::to(%index.1, %7, %6, %6, %5)
+        %10 : Tensor = aten::scatter(%data, %dim, %index, %src)
+        return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto index = at::randint(0, 4, {2, 2}, {at::kCUDA});
+  auto data = at::randn({5, 5}, {at::kCUDA});
+  auto src = at::randn({2, 2}, {at::kCUDA});
+
+  auto jit_index = at::clone(index);
+  auto jit_data = at::clone(data);
+  auto jit_src = at::clone(src);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_src, jit_index});
+
+  auto trt_index = at::clone(index);
+  auto trt_data = at::clone(data);
+  auto trt_src = at::clone(src);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_src, trt_index});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    auto trt = trt_results[i].reshape(jit_results[i].sizes());
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+  }
+}
\ No newline at end of file
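Note: for reference, the eager-mode semantics the two tests above check the TensorRT
engine against: along dim, scatter.value writes one scalar at every position named by
index, while scatter.src copies elements from a src tensor of index's shape, i.e. for
dim = 1, out[i][index[i][j]] = value (or src[i][j]). A minimal libtorch sketch follows;
the shapes and fill values are chosen for illustration only.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto data = torch::zeros({3, 4});
      auto index = torch::tensor({{0, 1}, {2, 3}});  // int64 indices, shape (2, 2)

      // scatter.value: out[i][index[i][j]] = 100 along dim = 1
      auto by_value = data.scatter(1, index, 100);

      // scatter.src: out[i][index[i][j]] = src[i][j] along dim = 1
      auto src = torch::ones({2, 2}) * 7;
      auto by_src = data.scatter(1, index, src);

      std::cout << by_value << "\n" << by_src << std::endl;
    }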