diff --git a/tests/core/partitioning/test_conditionals.cpp b/tests/core/partitioning/test_conditionals.cpp index 6a9b675086..395a99be1b 100644 --- a/tests/core/partitioning/test_conditionals.cpp +++ b/tests/core/partitioning/test_conditionals.cpp @@ -35,12 +35,13 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) { auto g = mod.get_method("forward").graph(); torch_tensorrt::core::CompileSpec cfg(inputs); cfg.partitioning_info.enabled = true; + cfg.partitioning_info.forced_fallback_operators.push_back("aten::log_sigmoid"); torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg); auto new_g = new_mod.get_method("forward").graph(); auto conditional_engines_count = count_trt_engines_in_conditionals(new_g); - ASSERT_TRUE(conditional_engines_count == 1); + ASSERT_TRUE(conditional_engines_count == 2); } TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {