From b31b7db62b5802fd2ca823b87afeb7b9d460745b Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Mon, 17 Apr 2023 18:16:43 -0700 Subject: [PATCH 1/2] Fix dependency order of inserted long input casts --- core/lowering/lowering.cpp | 6 ++-- tests/core/lowering/BUILD | 4 +++ .../lowering/test_autocast_long_inputs.cpp | 36 +++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 tests/core/lowering/test_autocast_long_inputs.cpp diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 0683956e5a..fa179a3922 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -31,6 +31,8 @@ int AutocastLongInputs( ir::TypeMap input_type_map, std::string target_device_name) { int num_autocasts = 0; + auto old_insert_point = g->insertPoint(); + g->setInsertPoint(g->nodes().front()); // For each graph input, determine if it can be autocasted for (size_t i = 0; i < g->inputs().size(); i++) { auto input = g->inputs()[i]; @@ -71,7 +73,7 @@ int AutocastLongInputs( auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val}); // Replace all uses of the original tensor with that of the casted tensor - g->prependNode(cast_node); + g->insertNode(cast_node); input->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]); // Mark the cast node to run in PyTorch for ease of casting @@ -80,7 +82,7 @@ int AutocastLongInputs( num_autocasts++; } } - + g->setInsertPoint(old_insert_point); LOG_GRAPH("Inserted " << num_autocasts << " autocasts"); if (num_autocasts > 0) { diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD index 801c6009c9..a0650c032f 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -27,6 +27,10 @@ cc_test( }), ) +lowering_test( + name = "test_autocast_long_inputs", +) + lowering_test( name = "test_conv_pass", ) diff --git a/tests/core/lowering/test_autocast_long_inputs.cpp b/tests/core/lowering/test_autocast_long_inputs.cpp new file mode 100644 index 0000000000..b0f7e235ec --- /dev/null +++ b/tests/core/lowering/test_autocast_long_inputs.cpp @@ -0,0 +1,36 @@ +#include +#include "core/compiler.h" +#include "core/lowering/passes/passes.h" +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/csrc/jit/ir/irparser.h" +#include "torch/csrc/jit/ir/subgraph_matcher.h" + +TEST(LoweringPasses, AutocastLongInputs) { + std::string source_graph = R"IR( + graph(%long_0 : Tensor, %long_1 : Tensor): + %res : Tensor = aten::add(%long_0, %long_1) + return (%res))IR"; + std::string target_graph = R"IR( + graph(%long_0 : Tensor, %long_1 : Tensor): + %3 : bool = prim::Constant[value=0]() + %4 : Device = prim::Constant[value="cuda:0"]() + %5 : NoneType = prim::Constant() + %6 : int = prim::Constant[value=4]() + %7 : Tensor = aten::to[to_compile=0](%long_0, %4, %6, %3, %3, %5) + %8 : int = prim::Constant[value=4]() + %9 : Tensor = aten::to[to_compile=0](%long_1, %4, %8, %3, %3, %5) + %2 : Tensor = aten::add(%7, %9) + return (%2))IR"; + + auto sg = std::make_shared(); + torch::jit::parseIR(source_graph, &*sg); + std::unordered_map> type_map; + type_map[sg->inputs()[0]] = at::kLong; + type_map[sg->inputs()[1]] = at::kLong; + torch_tensorrt::core::lowering::AutocastLongInputs(sg, type_map, "cuda:0"); + auto tg = std::make_shared(); + torch::jit::parseIR(target_graph, &*tg); + ASSERT_TRUE(sg->nodes().front()->kind() == torch::jit::prim::Constant); // confirm constants are added before casts + ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); +} From 2d9397f2cc5dd6652ce4e4b21609223a3dc57a73 Mon Sep 17 00:00:00 2001 From: Michael Feliz Date: Mon, 17 Apr 2023 19:52:18 -0700 Subject: [PATCH 2/2] Add autocast test to test suite --- tests/core/lowering/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD index a0650c032f..081443ecb3 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -106,6 +106,7 @@ lowering_test( test_suite( name = "lowering_tests", tests = [ + ":test_autocast_long_inputs", ":test_conv_pass", ":test_device_casting", ":test_exception_elimination_pass",