diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD index 7ff6651a27..de5d6b21a9 100644 --- a/core/lowering/passes/BUILD +++ b/core/lowering/passes/BUILD @@ -10,9 +10,7 @@ config_setting( cc_library( name = "passes", srcs = [ - "conv1d_to_convolution.cpp", - "conv2d_to_convolution.cpp", - "conv3d_to_convolution.cpp", + "convNd_to_convolution.cpp", "exception_elimination.cpp", "fuse_addmm_branches.cpp", "linear_to_addmm.cpp", diff --git a/core/lowering/passes/conv2d_to_convolution.cpp b/core/lowering/passes/conv2d_to_convolution.cpp deleted file mode 100644 index b1b3643624..0000000000 --- a/core/lowering/passes/conv2d_to_convolution.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include - -#include "core/util/prelude.h" - -namespace torch_tensorrt { -namespace core { -namespace lowering { -namespace passes { - -void Conv2DToConvolution(std::shared_ptr& graph) { - std::string conv2d_pattern = R"IR( - graph(%x, %w, %b, %s, %p, %d, %g): - %4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g) - return (%4))IR"; - std::string convolution_pattern = R"IR( - graph(%x, %w, %b, %s, %p, %d, %g): - %1 : bool = prim::Constant[value=0]() - %2 : int[] = prim::Constant[value=[0, 0]]() - %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1) - return (%4))IR"; - ; - - // replace matmul + add pattern to linear - torch::jit::SubgraphRewriter map_conv2d_to_convolution; - map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern); - map_conv2d_to_convolution.runOnGraph(graph); - LOG_GRAPH("Post map conv2d -> _convolution: " << *graph); -} - -} // namespace passes -} // namespace lowering -} // namespace core -} // namespace torch_tensorrt \ No newline at end of file diff --git a/core/lowering/passes/conv3d_to_convolution.cpp b/core/lowering/passes/conv3d_to_convolution.cpp deleted file mode 100644 index 69158a3364..0000000000 --- a/core/lowering/passes/conv3d_to_convolution.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include - -#include "core/util/prelude.h" - -namespace torch_tensorrt { -namespace core { -namespace lowering { -namespace passes { - -void Conv3DToConvolution(std::shared_ptr& graph) { - std::string conv3d_pattern = R"IR( - graph(%x, %w, %b, %s, %p, %d, %g): - %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g) - return (%4))IR"; - std::string convolution_pattern = R"IR( - graph(%x, %w, %b, %s, %p, %d, %g): - %1 : bool = prim::Constant[value=0]() - %2 : int[] = prim::Constant[value=[0, 0, 0]]() - %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1) - return (%4))IR"; - ; - - // replace matmul + add pattern to linear - torch::jit::SubgraphRewriter map_conv3d_to_convolution; - map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern); - map_conv3d_to_convolution.runOnGraph(graph); - LOG_GRAPH("Post map conv3d -> _convolution: " << *graph); -} - -} // namespace passes -} // namespace lowering -} // namespace core -} // namespace torch_tensorrt \ No newline at end of file diff --git a/core/lowering/passes/conv1d_to_convolution.cpp b/core/lowering/passes/convNd_to_convolution.cpp similarity index 51% rename from core/lowering/passes/conv1d_to_convolution.cpp rename to core/lowering/passes/convNd_to_convolution.cpp index bcd056e791..e5c4578a39 100644 --- a/core/lowering/passes/conv1d_to_convolution.cpp +++ b/core/lowering/passes/convNd_to_convolution.cpp @@ -2,7 +2,7 @@ #include "core/util/prelude.h" -namespace trtorch { +namespace torch_tensorrt { namespace core { namespace lowering { namespace passes { @@ -12,6 +12,7 @@ void Conv1DToConvolution(std::shared_ptr& graph) { graph(%x, %w, %b, %s, %p, %d, %g): %4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g) return (%4))IR"; + std::string convolution_pattern = R"IR( graph(%x, %w, %b, %s, %p, %d, %g): %1 : bool = prim::Constant[value=0]() @@ -43,7 +44,45 @@ void ConvTransposed1DToConvolution(std::shared_ptr& graph) { LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph); } +void Conv2DToConvolution(std::shared_ptr& graph) { + std::string conv2d_pattern = R"IR( + graph(%x, %w, %b, %s, %p, %d, %g): + %4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g) + return (%4))IR"; + std::string convolution_pattern = R"IR( + graph(%x, %w, %b, %s, %p, %d, %g): + %1 : bool = prim::Constant[value=0]() + %2 : int[] = prim::Constant[value=[0, 0]]() + %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1) + return (%4))IR"; + + // replace matmul + add pattern to linear + torch::jit::SubgraphRewriter map_conv2d_to_convolution; + map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern); + map_conv2d_to_convolution.runOnGraph(graph); + LOG_GRAPH("Post map conv2d -> _convolution: " << *graph); +} + +void Conv3DToConvolution(std::shared_ptr& graph) { + std::string conv3d_pattern = R"IR( + graph(%x, %w, %b, %s, %p, %d, %g): + %4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g) + return (%4))IR"; + std::string convolution_pattern = R"IR( + graph(%x, %w, %b, %s, %p, %d, %g): + %1 : bool = prim::Constant[value=0]() + %2 : int[] = prim::Constant[value=[0, 0, 0]]() + %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1) + return (%4))IR"; + + // replace matmul + add pattern to linear + torch::jit::SubgraphRewriter map_conv3d_to_convolution; + map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern); + map_conv3d_to_convolution.runOnGraph(graph); + LOG_GRAPH("Post map conv3d -> _convolution: " << *graph); +} + } // namespace passes } // namespace lowering } // namespace core -} // namespace trtorch \ No newline at end of file +} // namespace torch_tensorrt \ No newline at end of file diff --git a/tests/core/lowering/test_conv1d_pass.cpp b/tests/core/lowering/test_conv1d_pass.cpp index afd4389049..7a86152937 100644 --- a/tests/core/lowering/test_conv1d_pass.cpp +++ b/tests/core/lowering/test_conv1d_pass.cpp @@ -35,10 +35,10 @@ TEST(LoweringPasses, Conv1dCorrectly) { %12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3) return (%12))IR"; - trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH); + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(torch_tensorrt::core::util::logging::LogLevel::kGRAPH); auto sg = std::make_shared(); torch::jit::parseIR(source_graph, &*sg); - trtorch::core::lowering::passes::Conv1DToConvolution(sg); + torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg); auto tg = std::make_shared(); torch::jit::parseIR(target_graph, &*tg); @@ -50,13 +50,13 @@ TEST(LoweringPasses, Conv1dCorrectly) { auto trt_in = at::clone(in); auto trt_w = at::clone(w); auto trt_b = at::clone(b); - auto params = trtorch::core::conversion::get_named_params(sg->inputs(), {trt_w, trt_b}); - auto trt_results_sg = trtorch::tests::util::RunGraphEngine(sg, params, {trt_in}); + auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b}); + auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in}); - params = trtorch::core::conversion::get_named_params(tg->inputs(), {trt_w, trt_b}); - auto trt_results_tg = trtorch::tests::util::RunGraphEngine(tg, params, {trt_in}); + params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b}); + auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in}); - ASSERT_TRUE(trtorch::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6)); } TEST(LoweringPasses, ConvTransposed1dCorrectly) { @@ -92,10 +92,10 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) { %12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %8, %output_padding, %5, %7, %7, %7, %7) return (%12))IR"; - trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH); + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(torch_tensorrt::core::util::logging::LogLevel::kGRAPH); auto sg = std::make_shared(); torch::jit::parseIR(source_graph, &*sg); - trtorch::core::lowering::passes::ConvTransposed1DToConvolution(sg); + torch_tensorrt::core::lowering::passes::ConvTransposed1DToConvolution(sg); auto tg = std::make_shared(); torch::jit::parseIR(target_graph, &*tg); @@ -107,11 +107,11 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) { auto trt_in = at::clone(in); auto trt_w = at::clone(w); auto trt_b = at::clone(b); - auto params = trtorch::core::conversion::get_named_params(sg->inputs(), {trt_w, trt_b}); - auto trt_results_sg = trtorch::tests::util::RunGraphEngine(sg, params, {trt_in}); + auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b}); + auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in}); - params = trtorch::core::conversion::get_named_params(tg->inputs(), {trt_w, trt_b}); - auto trt_results_tg = trtorch::tests::util::RunGraphEngine(tg, params, {trt_in}); + params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b}); + auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in}); - ASSERT_TRUE(trtorch::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6)); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6)); } \ No newline at end of file