fix: Bugfix in Linear-to-AddMM Fusion Lowering Pass #1619

Merged 1 commit on Feb 1, 2023
core/lowering/passes/linear_to_addmm.cpp (68 changes: 41 additions & 27 deletions)

@@ -3,6 +3,7 @@
#include "core/util/prelude.h"
#include "torch/csrc/jit/api/function_impl.h"
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
@@ -16,26 +17,58 @@ namespace core {
namespace lowering {
namespace passes {

-void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph) {
+void replaceLinear(torch::jit::Block* block) {
  // Define the decomposition function for aten::linear for the case where bias (mat2) is None.
  static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT(
    def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
        return torch.matmul(self, mat1.t())
    )SCRIPT");

-  // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
-  auto block = graph->block();
+  // Define graph format for aten::linear with Tensor-type bias
+  std::string fused_linear = R"IR(
+    graph(%input, %weight, %bias):
+      %1: int = prim::Constant[value=1]()
+      %weight = aten::t(%weight)
+      %mm: Tensor = aten::matmul(%input, %weight)
+      %b_f: Tensor = trt::const(%bias)
+      %out: Tensor = aten::add(%b_f, %mm, %1)
+      return (%out))IR";

+  // Iterate through nodes in block, searching for aten::linear
  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
    auto n = *it;
-    if (n->kind().toQualString() == std::string("aten::linear")) {
+
+    // Recursively explore nested blocks, such as those arising from prim::If
+    for (auto block : n->blocks()) {
+      replaceLinear(block);
+    }
+
+    if ((n->kind().toQualString() == std::string("aten::linear")) && (n->inputs().size() >= 3)) {
      auto input_values = n->inputs();
-      // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
+
+      // input_values[2] is the bias
+      // If Tensor, replace with fused-bias decomposed graph
+      // Otherwise, replace it with the no-bias decomposed linear graph.
      if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) {
-        continue;
+        torch::jit::WithInsertPoint guard(*it);
+
+        // Initialize new fused subgraph from IR code above
+        auto fused_g = std::make_shared<torch::jit::Graph>();
+        torch::jit::parseIR(fused_linear, fused_g.get());
+
+        // Insert subgraph in place of aten::linear, replacing inputs and outputs accordingly
+        torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *fused_g, it->inputs()).at(0);
+        new_output->setType(it->output()->type());
+        it->output()->replaceAllUsesWith(new_output);
+        it.destroyCurrent();
      } else {
        torch::jit::WithInsertPoint guard(*it);

+        // Initialize decomposed graph without bias term
        std::shared_ptr<torch::jit::Graph> d_graph = toGraphFunction(decompose_funcs.get_function("linear")).graph();
        torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);

+        // Insert function in place of aten::linear, replacing inputs and outputs accordingly
        new_output->setType(it->output()->type());
        it->output()->replaceAllUsesWith(new_output);
        it.destroyCurrent();
@@ -45,27 +78,8 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph)
}

void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
-  // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
-  std::string flatten_linear_pattern = R"IR(
-    graph(%input, %weight, %bias):
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";
-
-  std::string fused_linear = R"IR(
-    graph(%input, %weight_t, %bias):
-      %1: int = prim::Constant[value=1]()
-      %weight = aten::t(%weight_t)
-      %mm: Tensor = aten::matmul(%input, %weight)
-      %b_f: Tensor = trt::const(%bias)
-      %out: Tensor = aten::add(%b_f, %mm, %1)
-      return (%out))IR";
-
-  // First find and replace aten::linear nodes with non-tensor bias values.
-  replaceLinearWithBiasNonePattern(graph);
-
-  torch::jit::SubgraphRewriter flatten_linear_to_linear;
-  flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
-  flatten_linear_to_linear.runOnGraph(graph);
+  // Recursively find and replace all instances of aten::linear with the corresponding decomposed form
+  replaceLinear(graph->block());
}

} // namespace passes
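
For intuition, both decomposed forms above implement the identity linear(input, weight, bias) == matmul(input, weight.t()) + bias. Below is a quick standalone sanity check of that identity (a sketch using LibTorch, not part of this PR; shapes mirror the regression tests below):

```cpp
#include <iostream>
#include <torch/torch.h>

int main() {
  // Shapes match the tests in this PR: input [8, 7], weight [8, 7], bias [8, 8]
  auto input = torch::rand({8, 7});
  auto weight = torch::rand({8, 7});
  auto bias = torch::rand({8, 8});

  // Reference: the composite op this pass eliminates
  auto reference = at::linear(input, weight, bias);

  // The decomposed form the pass emits: transpose, matmul, then add with alpha = 1,
  // mirroring aten::add(%b_f, %mm, %1) in the IR pattern above
  auto decomposed = torch::add(bias, torch::matmul(input, weight.t()), /*alpha=*/1);

  std::cout << std::boolalpha << torch::allclose(reference, decomposed) << std::endl;  // true
  return 0;
}
```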
tests/core/lowering/test_linear_to_addmm.cpp (78 changes: 78 additions & 0 deletions)

@@ -57,3 +57,81 @@ TEST(LoweringPasses, LinearToAddMMBiasNone) {

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, LinearToAddMMBiasNoneGraphRun) {
  std::string source_graph = R"IR(
    graph(%input, %weight):
      %biasNone : None = prim::Constant()
      %true : bool = prim::Constant[value=1]()
      %invalid_weight : Tensor = aten::t(%weight)
      %4 : Tensor = prim::If(%true)
        block0():
          %res = aten::linear(%input, %weight, %biasNone)
          -> (%res)
        block1():
          %res = aten::linear(%input, %invalid_weight, %biasNone)
          -> (%res)
      return (%4))IR";

  // This regression test ensures the Linear-to-AddMM lowering pass satisfies two constraints for non-Tensor bias:
  // 1. It recursively traverses sub-blocks within the node, replacing aten::linear instances there as well
  // 2. It does not pre-evaluate branches of the block, which may contain invalid operations

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, g.get());

  auto in_0 = at::rand({8, 7}, {at::kCUDA});
  auto in_1 = at::rand({8, 7}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});

  torch_tensorrt::core::lowering::passes::LinearToAddMM(g);

  LOG_DEBUG(g);

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
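
A complementary structural assertion could strengthen this test (a sketch; `blockContainsLinear` is a hypothetical helper, not part of this PR). It checks constraint 1 directly by verifying that no aten::linear node survives the pass in any nested sub-block:

```cpp
#include <string>
#include "torch/csrc/jit/ir/ir.h"

// Hypothetical helper: recursively search a block and all nested sub-blocks
// (e.g., prim::If branches) for any remaining aten::linear node.
bool blockContainsLinear(torch::jit::Block* block) {
  for (auto* node : block->nodes()) {
    if (node->kind().toQualString() == std::string("aten::linear")) {
      return true;
    }
    for (auto* sub : node->blocks()) {
      if (blockContainsLinear(sub)) {
        return true;
      }
    }
  }
  return false;
}

// Usage after running the pass in the test above:
//   torch_tensorrt::core::lowering::passes::LinearToAddMM(g);
//   ASSERT_FALSE(blockContainsLinear(g->block()));
```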

TEST(LoweringPasses, LinearToAddMMBiasGraphRun) {
  std::string source_graph = R"IR(
    graph(%input, %weight, %bias):
      %true : bool = prim::Constant[value=1]()
      %invalid_weight : Tensor = aten::t(%weight)
      %4 : Tensor = prim::If(%true)
        block0():
          %res = aten::linear(%input, %weight, %bias)
          -> (%res)
        block1():
          %res = aten::linear(%input, %invalid_weight, %bias)
          -> (%res)
      return (%4))IR";

  // This regression test ensures the Linear-to-AddMM lowering pass satisfies two constraints for Tensor bias:
  // 1. It recursively traverses sub-blocks within the node, replacing aten::linear instances there as well
  // 2. It does not pre-evaluate branches of the block, which may contain invalid operations

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, g.get());

  auto in_0 = at::rand({8, 7}, {at::kCUDA});
  auto in_1 = at::rand({8, 7}, {at::kCUDA});
  auto in_2 = at::rand({8, 8}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1, in_2});

  torch_tensorrt::core::lowering::passes::LinearToAddMM(g);

  LOG_DEBUG(g);

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1, in_2});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
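
For reference, a minimal standalone driver applying the pass outside the test harness (a sketch; the passes.h include path and build wiring are assumptions, and the IR here is the simplest graph the pass targets):

```cpp
#include <memory>
#include <string>
#include "core/lowering/passes/passes.h" // assumed header declaring LinearToAddMM
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

int main() {
  // A single aten::linear node with a Tensor-type bias
  const std::string source = R"IR(
    graph(%input, %weight, %bias):
      %res = aten::linear(%input, %weight, %bias)
      return (%res))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, g.get());

  // After the pass, aten::linear is replaced by aten::t + aten::matmul + aten::add;
  // with a None bias it would instead become aten::t + aten::matmul only.
  torch_tensorrt::core::lowering::passes::LinearToAddMM(g);
  g->dump(); // prints the lowered graph
  return 0;
}
```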