diff --git a/core/lowering/passes/linear_to_addmm.cpp b/core/lowering/passes/linear_to_addmm.cpp index e549aea39e..1944284a5f 100644 --- a/core/lowering/passes/linear_to_addmm.cpp +++ b/core/lowering/passes/linear_to_addmm.cpp @@ -3,6 +3,7 @@ #include "core/util/prelude.h" #include "torch/csrc/jit/api/function_impl.h" #include "torch/csrc/jit/ir/alias_analysis.h" +#include "torch/csrc/jit/ir/irparser.h" #include "torch/csrc/jit/jit_log.h" #include "torch/csrc/jit/passes/constant_propagation.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" @@ -16,26 +17,58 @@ namespace core { namespace lowering { namespace passes { -void replaceLinearWithBiasNonePattern(std::shared_ptr graph) { +void replaceLinear(torch::jit::Block* block) { // Define the decomposition function for aten::linear for the case where bias (mat2) is None. static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT( def linear(self: Tensor, mat1: Tensor, mat2: Tensor): return torch.matmul(self, mat1.t()) )SCRIPT"); - // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case) - auto block = graph->block(); + // Define graph format for aten::linear with Tensor-type bias + std::string fused_linear = R"IR( + graph(%input, %weight, %bias): + %1: int = prim::Constant[value=1]() + %weight = aten::t(%weight) + %mm: Tensor = aten::matmul(%input, %weight) + %b_f: Tensor = trt::const(%bias) + %out: Tensor = aten::add(%b_f, %mm, %1) + return (%out))IR"; + + // Iterate through nodes in block, seaching for aten::linear for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) { auto n = *it; - if (n->kind().toQualString() == std::string("aten::linear")) { + + // Recursively explore nested blocks, such as those arising from prim::If + for (auto block : n->blocks()) { + replaceLinear(block); + } + + if ((n->kind().toQualString() == std::string("aten::linear")) && (n->inputs().size() >= 3)) { auto input_values = n->inputs(); - // input_values[2] is the bias. If none, replace it with the decomposed linear graph. + + // input_values[2] is the bias + // If Tensor, replace with fused-bias decomposed graph + // If none, replace it with the decomposed linear graph. if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) { - continue; + torch::jit::WithInsertPoint guard(*it); + + // Initialize new fused subgraph from IR code above + auto fused_g = std::make_shared(); + torch::jit::parseIR(fused_linear, fused_g.get()); + + // Insert subgraph in place of aten::linear, replacing inputs and outputs accordingly + torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0); + new_output->setType(it->output()->type()); + it->output()->replaceAllUsesWith(new_output); + it.destroyCurrent(); } else { torch::jit::WithInsertPoint guard(*it); + + // Initialized decomposed graph without bias term std::shared_ptr d_graph = toGraphFunction(decompose_funcs.get_function("linear")).graph(); torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0); + + // Insert function in place of aten::linear, replacing inputs and outputs accordingly new_output->setType(it->output()->type()); it->output()->replaceAllUsesWith(new_output); it.destroyCurrent(); @@ -45,27 +78,8 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr graph) } void LinearToAddMM(std::shared_ptr& graph) { - // TensorRT implicitly adds a flatten layer infront of FC layers if necessary - std::string flatten_linear_pattern = R"IR( - graph(%input, %weight, %bias): - %res = aten::linear(%input, %weight, %bias) - return (%res))IR"; - - std::string fused_linear = R"IR( - graph(%input, %weight_t, %bias): - %1: int = prim::Constant[value=1]() - %weight = aten::t(%weight_t) - %mm: Tensor = aten::matmul(%input, %weight) - %b_f: Tensor = trt::const(%bias) - %out: Tensor = aten::add(%b_f, %mm, %1) - return (%out))IR"; - - // First find and replace aten::linear nodes with non-tensor bias values. - replaceLinearWithBiasNonePattern(graph); - - torch::jit::SubgraphRewriter flatten_linear_to_linear; - flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear); - flatten_linear_to_linear.runOnGraph(graph); + // Recursively find and replace all instances of aten::linear with the corresponding decomposed form + replaceLinear(graph->block()); } } // namespace passes diff --git a/tests/core/lowering/test_linear_to_addmm.cpp b/tests/core/lowering/test_linear_to_addmm.cpp index 2446916c3d..f1ff35ee31 100644 --- a/tests/core/lowering/test_linear_to_addmm.cpp +++ b/tests/core/lowering/test_linear_to_addmm.cpp @@ -57,3 +57,81 @@ TEST(LoweringPasses, LinearToAddMMBiasNone) { ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); } + +TEST(LoweringPasses, LinearToAddMMBiasNoneGraphRun) { + std::string source_graph = R"IR( + graph(%input, %weight): + %biasNone : None = prim::Constant() + %true : bool = prim::Constant[value=1]() + %invalid_weight : Tensor = aten::t(%weight) + %4 : Tensor = prim::If(%true) + block0(): + %res = aten::linear(%input, %weight, %biasNone) + -> (%res) + block1(): + %res = aten::linear(%input, %invalid_weight, %biasNone) + -> (%res) + return (%4))IR"; + + // This regression test case ensures the Linear-to-AddMM lowering pass satisfies two constraints for non-Tensor bias: + // 1. It recursively resolves sub-blocks within the node, replacing sub-blocks to be converted as well + // 2. It does not pre-evaluate branches of the block which may have invalid operations + + auto g = std::make_shared(); + torch::jit::parseIR(source_graph, g.get()); + + auto in_0 = at::rand({8, 7}, {at::kCUDA}); + auto in_1 = at::rand({8, 7}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1}); + + torch_tensorrt::core::lowering::passes::LinearToAddMM(g); + + LOG_DEBUG(g); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +} + +TEST(LoweringPasses, LinearToAddMMBiasGraphRun) { + std::string source_graph = R"IR( + graph(%input, %weight, %bias): + %true : bool = prim::Constant[value=1]() + %invalid_weight : Tensor = aten::t(%weight) + %4 : Tensor = prim::If(%true) + block0(): + %res = aten::linear(%input, %weight, %bias) + -> (%res) + block1(): + %res = aten::linear(%input, %invalid_weight, %bias) + -> (%res) + return (%4))IR"; + + // This regression test case ensures the Linear-to-AddMM lowering pass satisfies two constraints for Tensor bias: + // 1. It recursively resolves sub-blocks within the node, replacing sub-blocks to be converted as well + // 2. It does not pre-evaluate branches of the block which may have invalid operations + + auto g = std::make_shared(); + torch::jit::parseIR(source_graph, g.get()); + + auto in_0 = at::rand({8, 7}, {at::kCUDA}); + auto in_1 = at::rand({8, 7}, {at::kCUDA}); + auto in_2 = at::rand({8, 8}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1, in_2}); + + torch_tensorrt::core::lowering::passes::LinearToAddMM(g); + + LOG_DEBUG(g); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1, in_2}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +}