fix: Bugfix in Linear-to-AddMM Fusion Lowering Pass (#1619)
gs-olive authored Feb 1, 2023
1 parent 779cdea commit 3e422f5
Showing 2 changed files with 119 additions and 27 deletions.
68 changes: 41 additions & 27 deletions core/lowering/passes/linear_to_addmm.cpp
@@ -3,6 +3,7 @@
 #include "core/util/prelude.h"
 #include "torch/csrc/jit/api/function_impl.h"
 #include "torch/csrc/jit/ir/alias_analysis.h"
+#include "torch/csrc/jit/ir/irparser.h"
 #include "torch/csrc/jit/jit_log.h"
 #include "torch/csrc/jit/passes/constant_propagation.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
@@ -16,26 +17,58 @@ namespace core {
 namespace lowering {
 namespace passes {

-void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph) {
+void replaceLinear(torch::jit::Block* block) {
   // Define the decomposition function for aten::linear for the case where bias (mat2) is None.
   static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT(
     def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
         return torch.matmul(self, mat1.t())
     )SCRIPT");

-  // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
-  auto block = graph->block();
+  // Define graph format for aten::linear with Tensor-type bias
+  std::string fused_linear = R"IR(
+    graph(%input, %weight, %bias):
+      %1: int = prim::Constant[value=1]()
+      %weight = aten::t(%weight)
+      %mm: Tensor = aten::matmul(%input, %weight)
+      %b_f: Tensor = trt::const(%bias)
+      %out: Tensor = aten::add(%b_f, %mm, %1)
+      return (%out))IR";
+
+  // Iterate through nodes in block, searching for aten::linear
   for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
     auto n = *it;
-    if (n->kind().toQualString() == std::string("aten::linear")) {
+
+    // Recursively explore nested blocks, such as those arising from prim::If
+    for (auto block : n->blocks()) {
+      replaceLinear(block);
+    }
+
+    if ((n->kind().toQualString() == std::string("aten::linear")) && (n->inputs().size() >= 3)) {
       auto input_values = n->inputs();
-      // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
+
+      // input_values[2] is the bias
+      // If Tensor, replace with fused-bias decomposed graph
+      // Otherwise, replace it with the no-bias decomposed linear graph
       if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) {
-        continue;
+        torch::jit::WithInsertPoint guard(*it);
+
+        // Initialize new fused subgraph from IR code above
+        auto fused_g = std::make_shared<torch::jit::Graph>();
+        torch::jit::parseIR(fused_linear, fused_g.get());
+
+        // Insert subgraph in place of aten::linear, replacing inputs and outputs accordingly
+        torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
+        new_output->setType(it->output()->type());
+        it->output()->replaceAllUsesWith(new_output);
+        it.destroyCurrent();
       } else {
         torch::jit::WithInsertPoint guard(*it);
+
+        // Initialize decomposed graph without bias term
         std::shared_ptr<torch::jit::Graph> d_graph = toGraphFunction(decompose_funcs.get_function("linear")).graph();
         torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
+
+        // Insert function in place of aten::linear, replacing inputs and outputs accordingly
         new_output->setType(it->output()->type());
         it->output()->replaceAllUsesWith(new_output);
         it.destroyCurrent();
@@ -45,27 +78,8 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph)
 }

 void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
-  // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
-  std::string flatten_linear_pattern = R"IR(
-    graph(%input, %weight, %bias):
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";
-
-  std::string fused_linear = R"IR(
-    graph(%input, %weight_t, %bias):
-      %1: int = prim::Constant[value=1]()
-      %weight = aten::t(%weight_t)
-      %mm: Tensor = aten::matmul(%input, %weight)
-      %b_f: Tensor = trt::const(%bias)
-      %out: Tensor = aten::add(%b_f, %mm, %1)
-      return (%out))IR";
-
-  // First find and replace aten::linear nodes with non-tensor bias values.
-  replaceLinearWithBiasNonePattern(graph);
-
-  torch::jit::SubgraphRewriter flatten_linear_to_linear;
-  flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
-  flatten_linear_to_linear.runOnGraph(graph);
+  // Recursively find and replace all instances of aten::linear with the corresponding decomposed form
+  replaceLinear(graph->block());
 }

 } // namespace passes
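For context, here is a minimal before/after sketch of the rewrite this pass performs when the bias is a Tensor. The value names are illustrative (the pass's own IR string reuses the %weight name rather than introducing a separate %weight_t):

    graph(%input, %weight, %bias):
        %res : Tensor = aten::linear(%input, %weight, %bias)
        return (%res)

becomes

    graph(%input, %weight, %bias):
        %1 : int = prim::Constant[value=1]()
        %weight_t : Tensor = aten::t(%weight)
        %mm : Tensor = aten::matmul(%input, %weight_t)
        %b_f : Tensor = trt::const(%bias)
        %out : Tensor = aten::add(%b_f, %mm, %1)
        return (%out)

When the bias is None, the node is instead replaced with the decompose_funcs form, i.e. torch.matmul(%input, aten::t(%weight)) with no add.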
78 changes: 78 additions & 0 deletions tests/core/lowering/test_linear_to_addmm.cpp
@@ -57,3 +57,81 @@ TEST(LoweringPasses, LinearToAddMMBiasNone) {

   ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
 }

TEST(LoweringPasses, LinearToAddMMBiasNoneGraphRun) {
std::string source_graph = R"IR(
graph(%input, %weight):
%biasNone : None = prim::Constant()
%true : bool = prim::Constant[value=1]()
%invalid_weight : Tensor = aten::t(%weight)
%4 : Tensor = prim::If(%true)
block0():
%res = aten::linear(%input, %weight, %biasNone)
-> (%res)
block1():
%res = aten::linear(%input, %invalid_weight, %biasNone)
-> (%res)
return (%4))IR";

// This regression test case ensures the Linear-to-AddMM lowering pass satisfies two constraints for non-Tensor bias:
// 1. It recursively resolves sub-blocks within the node, replacing sub-blocks to be converted as well
// 2. It does not pre-evaluate branches of the block which may have invalid operations

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, g.get());

auto in_0 = at::rand({8, 7}, {at::kCUDA});
auto in_1 = at::rand({8, 7}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});

torch_tensorrt::core::lowering::passes::LinearToAddMM(g);

LOG_DEBUG(g);

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(LoweringPasses, LinearToAddMMBiasGraphRun) {
std::string source_graph = R"IR(
graph(%input, %weight, %bias):
%true : bool = prim::Constant[value=1]()
%invalid_weight : Tensor = aten::t(%weight)
%4 : Tensor = prim::If(%true)
block0():
%res = aten::linear(%input, %weight, %bias)
-> (%res)
block1():
%res = aten::linear(%input, %invalid_weight, %bias)
-> (%res)
return (%4))IR";

// This regression test case ensures the Linear-to-AddMM lowering pass satisfies two constraints for Tensor bias:
// 1. It recursively resolves sub-blocks within the node, replacing sub-blocks to be converted as well
// 2. It does not pre-evaluate branches of the block which may have invalid operations

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, g.get());

auto in_0 = at::rand({8, 7}, {at::kCUDA});
auto in_1 = at::rand({8, 7}, {at::kCUDA});
auto in_2 = at::rand({8, 8}, {at::kCUDA});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1, in_2});

torch_tensorrt::core::lowering::passes::LinearToAddMM(g);

LOG_DEBUG(g);

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1, in_2});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
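For anyone who wants to exercise the pass outside the test harness, here is a minimal sketch. It assumes a build linked against LibTorch and the Torch-TensorRT core libraries; the "core/lowering/passes/passes.h" header path is an assumption and may differ in the actual tree:

#include <memory>

#include "core/lowering/passes/passes.h" // assumed location of the LinearToAddMM declaration
#include "torch/csrc/jit/ir/irparser.h"

int main() {
  // Build a toy graph containing a single aten::linear node
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%input, %weight, %bias):
      %res = aten::linear(%input, %weight, %bias)
      return (%res))IR",
      g.get());

  // Rewrite aten::linear into the t/matmul/trt::const/add decomposed form
  torch_tensorrt::core::lowering::passes::LinearToAddMM(g);

  // Print the rewritten graph for inspection
  g->dump();
  return 0;
}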
