Merge pull request #1252 from inocsin/scatter
feat: support scatter.value and scatter.src
narendasan authored Aug 15, 2022
2 parents a64a3ac + 01e6541 · commit 298c3a3
Showing 6 changed files with 158 additions and 24 deletions.
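Before the diff itself, a minimal libtorch sketch (not part of the commit; values and shapes are illustrative) of the two aten::scatter overloads this PR adds converters for:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto data = torch::zeros({2, 4});
  auto index = torch::tensor({{0, 2}, {1, 3}});

  // scatter.value: write a scalar at the indexed positions along dim 1,
  // i.e. out[i][index[i][j]] = 100.
  auto by_value = torch::scatter(data, /*dim=*/1, index, /*value=*/100);
  // [[100, 0, 100, 0],
  //  [  0, 100, 0, 100]]

  // scatter.src: write elements taken from a src tensor instead,
  // i.e. out[i][index[i][j]] = src[i][j].
  auto src = torch::tensor({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto by_src = torch::scatter(data, /*dim=*/1, index, src);
  // [[1, 0, 2, 0],
  //  [0, 3, 0, 4]]

  std::cout << by_value << "\n" << by_src << "\n";
}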
21 changes: 21 additions & 0 deletions core/conversion/converters/converter_util.cpp
@@ -363,6 +363,27 @@ nvinfer1::ITensor* get_slice_size(
  return size_itensor;
}

nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
  nvinfer1::ITensor* out;
  if (s.isIntegral(false)) {
    auto s_int = s.to<int64_t>();
    auto s_t = torch::tensor({s_int}).to(at::kInt);
    out = tensor_to_const(ctx, s_t);
  } else if (s.isBoolean()) {
    auto s_bool = s.to<bool>();
    auto s_t = torch::tensor({s_bool}).to(at::kBool);
    out = tensor_to_const(ctx, s_t);
  } else if (s.isFloatingPoint()) {
    auto s_float = s.to<float>();
    auto s_t = torch::tensor({s_float});
    out = tensor_to_const(ctx, s_t);
  } else {
    out = nullptr;
    TORCHTRT_THROW_ERROR("Unsupported data type for scalar. Found: (" << s.type() << ")");
  }
  return out;
}

} // namespace converters
} // namespace conversion
} // namespace core
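For readers without a TensorRT build handy, a self-contained sketch (not part of the commit; scalar_to_1d_tensor is a hypothetical stand-in) of the same type dispatch, stopping at the at::Tensor stage that tensor_to_const would then freeze into a TensorRT constant:

#include <torch/torch.h>
#include <iostream>
#include <stdexcept>

// Mirrors scalar_to_tensor's dispatch: widen an at::Scalar to a one-element
// tensor whose dtype follows the scalar's category (integral -> Int,
// bool -> Bool, floating point -> Float).
at::Tensor scalar_to_1d_tensor(const at::Scalar& s) {
  if (s.isIntegral(false)) {
    return torch::tensor({s.to<int64_t>()}).to(at::kInt);
  } else if (s.isBoolean()) {
    return torch::tensor({s.to<bool>()}).to(at::kBool);
  } else if (s.isFloatingPoint()) {
    return torch::tensor({s.to<float>()});
  }
  throw std::runtime_error("Unsupported scalar type");
}

int main() {
  std::cout << scalar_to_1d_tensor(at::Scalar(7)).dtype() << "\n";    // Int
  std::cout << scalar_to_1d_tensor(at::Scalar(true)).dtype() << "\n"; // Bool
  std::cout << scalar_to_1d_tensor(at::Scalar(2.5)).dtype() << "\n";  // Float
}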
2 changes: 2 additions & 0 deletions core/conversion/converters/converter_util.h
@@ -80,6 +80,8 @@ nvinfer1::ITensor* get_slice_size(
    int nbdims,
    std::string const& name);

nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s);

} // namespace converters
} // namespace conversion
} // namespace core
21 changes: 1 addition & 20 deletions core/conversion/converters/impl/element_wise.cpp
@@ -25,26 +25,7 @@ nvinfer1::ITensor* clamp_util(
  return clamp_layer_out;
}

nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
  nvinfer1::ITensor* out;
  if (s.isIntegral(false)) {
    auto s_int = s.to<int64_t>();
    auto s_t = torch::tensor({s_int}).to(at::kInt);
    out = tensor_to_const(ctx, s_t);
  } else if (s.isBoolean()) {
    auto s_bool = s.to<bool>();
    auto s_t = torch::tensor({s_bool}).to(at::kBool);
    out = tensor_to_const(ctx, s_t);
  } else if (s.isFloatingPoint()) {
    auto other_float = s.to<float>();
    auto s_t = torch::tensor({other_float});
    out = tensor_to_const(ctx, s_t);
  } else {
    out = nullptr;
    TORCHTRT_THROW_ERROR("Unsupported data type for scalar. Found: (" << s.type() << ")");
  }
  return out;
}

auto element_wise_registrations TORCHTRT_UNUSED =
RegisterNodeConversionPatterns()
47 changes: 47 additions & 0 deletions core/conversion/converters/impl/select.cpp
@@ -464,6 +464,53 @@ auto select_registrations TORCHTRT_UNUSED =
                 auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
                 LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                 return true;
               }})
        .pattern(
            {"aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
               int dim = args[1].unwrapToInt();
               auto index = args[2].ITensorOrFreeze(ctx);
               auto index_dim = index->getDimensions();
               std::vector<int64_t> dim_vec;
               for (int i = 0; i < index_dim.nbDims; i++) {
                 dim_vec.push_back(index_dim.d[i]);
               }
               // Broadcast the scalar to the shape of index: the scatter layer
               // expects a full updates tensor, not a scalar.
               auto value = args[3].unwrapToScalar() * torch::ones(dim_vec);
               auto value_tensor = tensor_to_const(ctx, value, "");
               if (self->getType() != value_tensor->getType()) {
                 value_tensor = castITensor(ctx, value_tensor, self->getType());
               }

               auto layer = ctx->net->addScatter(*self, *index, *value_tensor, nvinfer1::ScatterMode::kELEMENT);
               TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.value");
               layer->setAxis(dim);
               layer->setName(util::node_info(n).c_str());

               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
               return true;
             }})
        .pattern(
            {"aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
               int dim = args[1].unwrapToInt();
               auto index = args[2].ITensorOrFreeze(ctx);
               auto src = args[3].ITensorOrFreeze(ctx);

               auto layer = ctx->net->addScatter(*self, *index, *src, nvinfer1::ScatterMode::kELEMENT);
               TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.src");
               layer->setAxis(dim);
               layer->setName(util::node_info(n).c_str());

               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
               return true;
             }});

} // namespace
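One detail worth calling out in the scatter.value converter: TensorRT's IScatterLayer in kELEMENT mode takes an updates tensor shaped like index, so the Scalar argument is first broadcast to that shape. A tiny libtorch sketch of just that materialization step (not part of the commit; dimensions are illustrative):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Stand-in for the dims read from index->getDimensions() in the converter.
  std::vector<int64_t> index_dims = {2, 2};

  // Scalar * ones(index_dims) yields a {2, 2} tensor filled with 100,
  // ready to be frozen into the network via tensor_to_const.
  auto updates = at::Scalar(100) * torch::ones(index_dims);
  std::cout << updates << "\n";
}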
18 changes: 14 additions & 4 deletions core/lowering/passes/op_aliasing.cpp
@@ -16,15 +16,25 @@ void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph) {
    graph(%s, %o):
      %1 : Tensor = aten::div(%s, %o)
      return (%1))IR";

  // TODO: complete aliasing passes for the other element-wise ops

  torch::jit::SubgraphRewriter true_divide_to_div;
  true_divide_to_div.RegisterRewritePattern(true_divide_pattern, div_pattern);
  true_divide_to_div.runOnGraph(graph);
  LOG_GRAPH("Post map true_divide -> div: " << *graph);

  std::string scatter_sub_pattern = R"IR(
    graph(%data, %dim, %index, %value):
      %o : Tensor = aten::scatter_(%data, %dim, %index, %value)
      return (%o))IR";
  std::string scatter_pattern = R"IR(
    graph(%data, %dim, %index, %value):
      %o : Tensor = aten::scatter(%data, %dim, %index, %value)
      return (%o))IR";

  torch::jit::SubgraphRewriter rewrite_scatter;
  rewrite_scatter.RegisterRewritePattern(scatter_sub_pattern, scatter_pattern);
  rewrite_scatter.runOnGraph(graph);
  LOG_GRAPH("Post map scatter_ -> scatter: " << *graph);
}

} // namespace passes
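Why this pass exists: the converters above only match the functional aten::scatter, while traced or scripted modules typically emit the in-place aten::scatter_. A self-contained sketch (assumes a libtorch build; not part of the commit) showing the same rewrite in isolation:

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

int main() {
  // A graph using the in-place variant, as a frontend might emit it.
  const std::string src = R"IR(
    graph(%data : Tensor, %index : Tensor):
      %dim : int = prim::Constant[value=1]()
      %value : int = prim::Constant[value=100]()
      %o : Tensor = aten::scatter_(%data, %dim, %index, %value)
      return (%o))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());

  // Same pattern pair as the pass above: map scatter_ onto scatter.
  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      R"IR(
        graph(%data, %dim, %index, %value):
          %o : Tensor = aten::scatter_(%data, %dim, %index, %value)
          return (%o))IR",
      R"IR(
        graph(%data, %dim, %index, %value):
          %o : Tensor = aten::scatter(%data, %dim, %index, %value)
          return (%o))IR");
  rewriter.runOnGraph(g);

  g->dump(); // %o is now produced by aten::scatter
}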
73 changes: 73 additions & 0 deletions tests/core/conversion/converters/test_select.cpp
@@ -855,3 +855,76 @@ TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}
}

TEST(Converters, ScatterValueConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%data : Tensor,
          %index.1 : Tensor):
      %value : int = prim::Constant[value=100]()
      %dim : int = prim::Constant[value=1]()
      %5 : NoneType = prim::Constant()
      %6 : bool = prim::Constant[value=0]()
      %7 : int = prim::Constant[value=4]()
      %index : Tensor = aten::to(%index.1, %7, %6, %6, %5)
      %10 : Tensor = aten::scatter(%data, %dim, %index, %value)
      return (%10))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto index = at::randint(0, 5, {2, 2}, {at::kCUDA});
  auto data = at::randn({5, 5}, {at::kCUDA});

  auto jit_index = at::clone(index);
  auto jit_data = at::clone(data);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_index});

  auto trt_index = at::clone(index);
  auto trt_data = at::clone(data);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_index});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

TEST(Converters, ScatterSrcConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%data : Tensor,
          %src : Tensor,
          %index.1 : Tensor):
      %dim : int = prim::Constant[value=1]()
      %5 : NoneType = prim::Constant()
      %6 : bool = prim::Constant[value=0]()
      %7 : int = prim::Constant[value=4]()
      %index : Tensor = aten::to(%index.1, %7, %6, %6, %5)
      %10 : Tensor = aten::scatter(%data, %dim, %index, %src)
      return (%10))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  auto index = at::randint(0, 4, {2, 2}, {at::kCUDA});
  auto data = at::randn({5, 5}, {at::kCUDA});
  auto src = at::randn({2, 2}, {at::kCUDA});

  auto jit_index = at::clone(index);
  auto jit_data = at::clone(data);
  auto jit_src = at::clone(src);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_src, jit_index});

  auto trt_index = at::clone(index);
  auto trt_data = at::clone(data);
  auto trt_src = at::clone(src);
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_src, trt_index});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}
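A side note on the aten::to cast both tests apply to %index.1 (sketch below is not part of the commit): PyTorch's scatter requires int64 indices, and the dtype constant 4 in the IR is at::kLong, so the cast keeps the JIT reference path valid regardless of how the index tensor arrives:

#include <torch/torch.h>

int main() {
  auto data = torch::zeros({5, 5});
  auto int32_index = torch::randint(0, 5, {2, 2}).to(at::kInt); // int32
  auto int64_index = int32_index.to(at::kLong);                 // int64, as in the tests

  // torch::scatter(data, 1, int32_index, 100);  // would throw: expected Long index
  auto out = torch::scatter(data, /*dim=*/1, int64_index, /*value=*/100);
}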
