Skip to content

Commit

Permalink
fix: Minor bugfix in partitioning test
Browse files Browse the repository at this point in the history
- Partitioning test incorrectly expected 1 conditional engine, but got 2
since `log_sigmoid` operator is not currently supported
  • Loading branch information
gs-olive committed Feb 23, 2023
1 parent a32e254 commit 746a9d6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
19 changes: 14 additions & 5 deletions tests/core/lowering/test_conv1d_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/subgraph_matcher.h"
#include "torch/csrc/jit/passes/canonicalize.h"
#include "torch/csrc/jit/passes/constant_pooling.h"

TEST(LoweringPasses, Conv1dCorrectly) {
const auto source_graph = R"IR(
Expand Down Expand Up @@ -119,7 +121,7 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {
}

TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
const auto source_graph = R"IR(
std::string source_graph = R"IR(
graph(%0 : Tensor,
%1 : Float(4, 3, 3, strides=[9, 3, 1]),
%2 : Float(3)):
Expand All @@ -142,21 +144,21 @@ TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
-> (%res)
return (%12))IR";

const auto target_graph = R"IR(
std::string target_graph = R"IR(
graph(%0 : Tensor,
%1 : Float(4, 3, 3, strides=[9, 3, 1]),
%2 : Float(3)):
%3 : bool = prim::Constant[value=0]()
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=1]()
%true : bool = prim::Constant[value=1]()
%3 : bool = prim::Constant[value=0]()
%output_padding : int[] = prim::Constant[value=[0]]()
%6 : int = prim::Constant[value=1]()
%stride : int[] = prim::ListConstruct(%6)
%padding : int[] = prim::ListConstruct(%4)
%dilation : int[] = prim::ListConstruct(%5)
%output_padding : int[] = prim::Constant[value=[0]]()
# Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
%true : bool = prim::Constant[value=1]()
%invalid_weight : Tensor = aten::transpose(%0, %4, %5)
%12 : Tensor = prim::If(%true)
block0():
Expand All @@ -172,9 +174,16 @@ TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, &*sg);
torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);
torch::jit::ConstantPooling(sg);
sg = torch::jit::Canonicalize(sg, false);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, &*tg);
torch::jit::ConstantPooling(tg);
tg = torch::jit::Canonicalize(tg, false);

// Validate identical graphs after pooling constants and canonicalizing
ASSERT_TRUE((tg->toString() == sg->toString()));

auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
Expand Down
3 changes: 2 additions & 1 deletion tests/core/partitioning/test_conditionals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
auto g = mod.get_method("forward").graph();
torch_tensorrt::core::CompileSpec cfg(inputs);
cfg.partitioning_info.enabled = true;
cfg.partitioning_info.forced_fallback_operators.push_back("aten::log_sigmoid");
torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
auto new_g = new_mod.get_method("forward").graph();

auto conditional_engines_count = count_trt_engines_in_conditionals(new_g);

ASSERT_TRUE(conditional_engines_count == 1);
ASSERT_TRUE(conditional_engines_count == 2);
}

TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
Expand Down

0 comments on commit 746a9d6

Please sign in to comment.