Skip to content

Commit

Permalink
feat: support aten::arange converter
Browse files Browse the repository at this point in the history
Signed-off-by: inocsin <vcheungyi@163.com>
  • Loading branch information
inocsin committed Mar 19, 2021
1 parent 5b6bd4c commit 014e381
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 1 deletion.
58 changes: 57 additions & 1 deletion core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,63 @@ auto aten_registrations TRTORCH_UNUSED =
LOG_WARNING("Warning from TorchScript: " << *warning);
return {};
},
EvalOptions()});
EvalOptions()})
.evaluator({c10::Symbol::fromQualString("aten::arange"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// int end_scalar = 0;
// auto end_scalar = ceil(args.at(n->input(0)).unwrapToScalar());
int input_size = n->inputs().size();
int scalar_count = 0;
for (int i = 0; i < input_size; i++) {
if (args.at(n->input(i)).IValue()->isScalar()) {
scalar_count += 1;
}
}
if (scalar_count == 1) {
if (args.at(n->input(0)).IValue()->isInt()) {
int end_scalar = args.at(n->input(0)).unwrapToInt();
return torch::arange(end_scalar);
} else if (args.at(n->input(0)).IValue()->isDouble()) {
float end_scalar = ceil(args.at(n->input(0)).unwrapToScalar().to<float>());
return torch::arange(end_scalar);
}
} else if (scalar_count == 2) {
if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble()) {
float start_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
float end_scalar = args.at(n->input(1)).unwrapToScalar().to<float>();
return torch::arange(start_scalar, end_scalar);
} else {
int start_scalar = args.at(n->input(0)).unwrapToInt();
int end_scalar = args.at(n->input(1)).unwrapToInt();
return torch::arange(start_scalar, end_scalar);
}
} else if (scalar_count == 3) {
if (args.at(n->input(0)).IValue()->isDouble() || args.at(n->input(1)).IValue()->isDouble() ||
args.at(n->input(2)).IValue()->isDouble()) {
float start_scalar = args.at(n->input(0)).unwrapToScalar().to<float>();
float end_scalar = args.at(n->input(1)).unwrapToScalar().to<float>();
float step_scalar = args.at(n->input(2)).unwrapToScalar().to<float>();
return torch::arange(start_scalar, end_scalar, step_scalar);
} else {
int start_scalar = args.at(n->input(0)).unwrapToInt();
int end_scalar = args.at(n->input(1)).unwrapToInt();
int step_scalar = args.at(n->input(2)).unwrapToInt();
return torch::arange(start_scalar, end_scalar, step_scalar);
}
} else {
TRTORCH_THROW_ERROR(
"Invalid input argument size for aten::arange, input argument size: " << input_size);
}
return {};
},
EvalOptions().validSchemas({
R"SIG(aten::arange(Scalar end, *, int? dtype=None, int? layout=None,
Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
R"SIG(aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None,
Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
R"SIG(aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None,
Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
})});
} // namespace
} // namespace evaluators
} // namespace conversion
Expand Down
103 changes: 103 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,107 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});

ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
}

TEST(Evaluators, ATenArangeIntEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%0 : int = prim::Constant[value=51]()
%1 : None = prim::Constant()
%2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
return (%2))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
}

TEST(Evaluators, ATenArangeFloatEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%0 : float = prim::Constant[value=51.2]()
%1 : None = prim::Constant()
%2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
return (%2))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
}

TEST(Evaluators, ATenArangeStartEndIntEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%0 : int = prim::Constant[value=1]()
%1 : int = prim::Constant[value=51]()
%2 : None = prim::Constant()
%3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
}

TEST(Evaluators, ATenArangeStartEndFloatEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%0 : float = prim::Constant[value=1.5]()
%1 : float = prim::Constant[value=51.2]()
%2 : None = prim::Constant()
%3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
}

TEST(Evaluators, ATenArangeStartEndStepIntEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%0 : int = prim::Constant[value=1]()
%1 : int = prim::Constant[value=51]()
%2 : int = prim::Constant[value=1]()
%3 : None = prim::Constant()
%4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
}

TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
const auto graph = R"IR(
graph():
%0 : float = prim::Constant[value=1.2]()
%1 : float = prim::Constant[value=51.6]()
%2 : float = prim::Constant[value=1.5]()
%3 : None = prim::Constant()
%4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
}

0 comments on commit 014e381

Please sign in to comment.