fix: Bugfix in convNd_to_convolution lowering pass
- The lowering pass did not respect `prim::If` block boundaries
- Refactor the convNd implementation to use a more precise guard-insert
paradigm instead of subgraph rewriting
- Write a general helper function shared by all convolution replacements
- When a matched subgraph occurs within an "If" block, the subgraph
rewriter places the replacement logic outside of the block, so the
rewritten code executes both the "if" and the "else" paths regardless
of the condition (see the IR sketch below)
- Add a test case to validate the refactoring on conv1d
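
For illustration, here is a minimal IR sketch of the failure mode described
above. The graph is hypothetical (not taken from the repository); the
aten::_convolution argument order follows the replacement pattern used in
the pass.

    # Before lowering: the convolution is guarded by the conditional
    graph(%cond : bool, %x, %w, %b, %s, %p, %d, %g, %fallback : Tensor):
      %out : Tensor = prim::If(%cond)
        block0():
          %res : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
          -> (%res)
        block1():
          -> (%fallback)
      return (%out)

    # After the old SubgraphRewriter-based pass: the replacement
    # aten::_convolution lands outside the prim::If, so it executes
    # unconditionally, regardless of %cond
    graph(%cond : bool, %x, %w, %b, %s, %p, %d, %g, %fallback : Tensor):
      %false : bool = prim::Constant[value=0]()
      %out_pad : int[] = prim::Constant[value=[0]]()
      %res : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %false, %out_pad, %g, %false, %false, %false, %false)
      %out : Tensor = prim::If(%cond)
        block0():
          -> (%res)
        block1():
          -> (%fallback)
      return (%out)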
gs-olive committed Feb 22, 2023
1 parent deda87b commit a32e254
Showing 2 changed files with 123 additions and 35 deletions.
84 changes: 49 additions & 35 deletions core/lowering/passes/convNd_to_convolution.cpp
@@ -1,4 +1,5 @@
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include "torch/csrc/jit/ir/irparser.h"

#include "core/util/prelude.h"

@@ -7,78 +8,91 @@ namespace core {
namespace lowering {
namespace passes {

void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
std::string conv1d_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
return (%4))IR";
void replaceConv(
torch::jit::Block* block,
const std::string& node_kind,
const std::string& unwrapped_conv,
const size_t num_input_args) {
// Iterate through nodes in the block, searching for aten::conv*
for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
auto n = *it;

// Recursively explore nested blocks, such as those arising from prim::If
for (auto nested_block : n->blocks()) {
replaceConv(nested_block, node_kind, unwrapped_conv, num_input_args);
}

// If node matches desired kind and number of input arguments, replace it
if ((n->kind().toQualString() == node_kind) && (n->inputs().size() == num_input_args)) {
// Establish insert point within block
torch::jit::WithInsertPoint guard(*it);

// Initialize new fused subgraph from IR code provided
auto fused_g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(unwrapped_conv, fused_g.get());

// Insert subgraph in place of aten::conv*, replacing inputs and outputs accordingly
torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
new_output->setType(it->output()->type());
it->output()->replaceAllUsesWith(new_output);
it.destroyCurrent();
}
}
}

std::string convolution_pattern = R"IR(
void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
const std::string conv1d_node_kind = "aten::conv1d";
const std::string convolution_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%1 : bool = prim::Constant[value=0]()
%2 : int[] = prim::Constant[value=[0]]()
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
return (%4))IR";

torch::jit::SubgraphRewriter map_conv1d_to_convolution;
map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
map_conv1d_to_convolution.runOnGraph(graph);
// Schema is aten::conv1d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
replaceConv(graph->block(), conv1d_node_kind, convolution_pattern, 7);
LOG_GRAPH("Post map conv1d -> _convolution: " << *graph);
}

void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
std::string conv_transpose1d_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %o, %g, %d):
%4 : Tensor = aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d)
return (%4))IR";
std::string convolution_pattern = R"IR(
const std::string conv_transpose1d_node_kind = "aten::conv_transpose1d";
const std::string convolution_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %o, %g, %d):
%1 : bool = prim::Constant[value=1]()
%2 : bool = prim::Constant[value=1]()
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
return (%4))IR";

torch::jit::SubgraphRewriter map_conv_transpose1d_to_convolution;
map_conv_transpose1d_to_convolution.RegisterRewritePattern(conv_transpose1d_pattern, convolution_pattern);
map_conv_transpose1d_to_convolution.runOnGraph(graph);
// Schema is aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d) --> 8 inputs
replaceConv(graph->block(), conv_transpose1d_node_kind, convolution_pattern, 8);
LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
}

void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
std::string conv2d_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
return (%4))IR";
std::string convolution_pattern = R"IR(
const std::string conv2d_node_kind = "aten::conv2d";
const std::string convolution_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%1 : bool = prim::Constant[value=0]()
%2 : int[] = prim::Constant[value=[0, 0]]()
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
return (%4))IR";

// replace matmul + add pattern to linear
torch::jit::SubgraphRewriter map_conv2d_to_convolution;
map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern);
map_conv2d_to_convolution.runOnGraph(graph);
// Schema is aten::conv2d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
replaceConv(graph->block(), conv2d_node_kind, convolution_pattern, 7);
LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
}

void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
std::string conv3d_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
return (%4))IR";
std::string convolution_pattern = R"IR(
const std::string conv3d_node_kind = "aten::conv3d";
const std::string convolution_pattern = R"IR(
graph(%x, %w, %b, %s, %p, %d, %g):
%1 : bool = prim::Constant[value=0]()
%2 : int[] = prim::Constant[value=[0, 0, 0]]()
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
return (%4))IR";

// replace matmul + add pattern to linear
torch::jit::SubgraphRewriter map_conv3d_to_convolution;
map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern);
map_conv3d_to_convolution.runOnGraph(graph);
// Schema is aten::conv3d(%x, %w, %b, %s, %p, %d, %g) --> 7 inputs
replaceConv(graph->block(), conv3d_node_kind, convolution_pattern, 7);
LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
}

74 changes: 74 additions & 0 deletions tests/core/lowering/test_conv1d_pass.cpp
@@ -117,3 +117,77 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
}

TEST(LoweringPasses, Conv1dWithConditionalLowersCorrectly) {
const auto source_graph = R"IR(
graph(%0 : Tensor,
%1 : Float(4, 3, 3, strides=[9, 3, 1]),
%2 : Float(3)):
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=1]()
%stride : int[] = prim::ListConstruct(%6)
%padding : int[] = prim::ListConstruct(%4)
%dilation : int[] = prim::ListConstruct(%5)
# Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
%true : bool = prim::Constant[value=1]()
%invalid_weight : Tensor = aten::transpose(%0, %4, %5)
%12 : Tensor = prim::If(%true)
block0():
%res: Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
-> (%res)
block1():
%res: Tensor = aten::conv1d(%invalid_weight, %1, %2, %stride, %padding, %dilation, %6)
-> (%res)
return (%12))IR";

const auto target_graph = R"IR(
graph(%0 : Tensor,
%1 : Float(4, 3, 3, strides=[9, 3, 1]),
%2 : Float(3)):
%3 : bool = prim::Constant[value=0]()
%4 : int = prim::Constant[value=0]()
%5 : int = prim::Constant[value=1]()
%6 : int = prim::Constant[value=1]()
%stride : int[] = prim::ListConstruct(%6)
%padding : int[] = prim::ListConstruct(%4)
%dilation : int[] = prim::ListConstruct(%5)
%output_padding : int[] = prim::Constant[value=[0]]()
# Add intentionally-invalid weight tensor to ensure prim::If blocks are respected
%true : bool = prim::Constant[value=1]()
%invalid_weight : Tensor = aten::transpose(%0, %4, %5)
%12 : Tensor = prim::If(%true)
block0():
%res: Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
-> (%res)
block1():
%res: Tensor = aten::_convolution(%invalid_weight, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
-> (%res)
return (%12))IR";

torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, &*sg);
torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, &*tg);

auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
auto b = at::randint(1, 10, {4}, {at::kCUDA});

auto trt_in = at::clone(in);
auto trt_w = at::clone(w);
auto trt_b = at::clone(b);
auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});

params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
}
