diff --git a/tests/core/lowering/test_conv1d_pass.cpp b/tests/core/lowering/test_conv1d_pass.cpp index 2b2f51306c..3694559108 100644 --- a/tests/core/lowering/test_conv1d_pass.cpp +++ b/tests/core/lowering/test_conv1d_pass.cpp @@ -5,6 +5,8 @@ #include "tests/util/util.h" #include "torch/csrc/jit/ir/irparser.h" #include "torch/csrc/jit/ir/subgraph_matcher.h" +#include "torch/csrc/jit/passes/canonicalize.h" +#include "torch/csrc/jit/passes/constant_pooling.h" TEST(LoweringPasses, Conv1dCorrectly) { const auto source_graph = R"IR( @@ -119,7 +121,7 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) { } TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) { - const auto source_graph = R"IR( + std::string source_graph = R"IR( graph(%0 : Tensor, %1 : Float(4, 3, 3, strides=[9, 3, 1]), %2 : Float(3)): @@ -142,21 +144,21 @@ TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) { -> (%res) return (%12))IR"; - const auto target_graph = R"IR( + std::string target_graph = R"IR( graph(%0 : Tensor, %1 : Float(4, 3, 3, strides=[9, 3, 1]), %2 : Float(3)): - %3 : bool = prim::Constant[value=0]() %4 : int = prim::Constant[value=0]() %5 : int = prim::Constant[value=1]() + %true : bool = prim::Constant[value=1]() + %3 : bool = prim::Constant[value=0]() + %output_padding : int[] = prim::Constant[value=[0]]() %6 : int = prim::Constant[value=1]() %stride : int[] = prim::ListConstruct(%6) %padding : int[] = prim::ListConstruct(%4) %dilation : int[] = prim::ListConstruct(%5) - %output_padding : int[] = prim::Constant[value=[0]]() # Add intentionally-invalid weight tensor to ensure prim::If blocks are respected - %true : bool = prim::Constant[value=1]() %invalid_weight : Tensor = aten::transpose(%0, %4, %5) %12 : Tensor = prim::If(%true) block0(): @@ -172,9 +174,16 @@ auto sg = std::make_shared<torch::jit::Graph>(); torch::jit::parseIR(source_graph, &*sg); torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg); + 
torch::jit::ConstantPooling(sg); + sg = torch::jit::Canonicalize(sg, false); auto tg = std::make_shared<torch::jit::Graph>(); torch::jit::parseIR(target_graph, &*tg); + torch::jit::ConstantPooling(tg); + tg = torch::jit::Canonicalize(tg, false); + + // Validate identical graphs after pooling constants and canonicalizing + ASSERT_TRUE((tg->toString() == sg->toString())); auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA}); auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA}); diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp index 6a9b675086..395a99be1b 100644 --- a/tests/core/partitioning/test_conditionals.cpp +++ b/tests/core/partitioning/test_conditionals.cpp @@ -35,12 +35,13 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) { auto g = mod.get_method("forward").graph(); torch_tensorrt::core::CompileSpec cfg(inputs); cfg.partitioning_info.enabled = true; + cfg.partitioning_info.forced_fallback_operators.push_back("aten::log_sigmoid"); torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg); auto new_g = new_mod.get_method("forward").graph(); auto conditional_engines_count = count_trt_engines_in_conditionals(new_g); - ASSERT_TRUE(conditional_engines_count == 1); + ASSERT_TRUE(conditional_engines_count == 2); } TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {