Add linear+tanh fusion for inference (#685)
* init linear tanh fusion

* add ut
jianan-gu authored Apr 13, 2022
1 parent 3b6cb10 commit f0f2bae
Showing 6 changed files with 97 additions and 1 deletion.
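The diff touches the ideep post-op attributes (attr_t::fuse_tanh), the prepacked linear kernel entry point (linear_tanh_run), the TorchScript graph-rewrite pass, the JIT operator registration, and the unit tests. As a minimal sketch of how the fusion could be exercised from Python, modeled on the new test but not part of this commit (the ipex.optimize / autocast / freeze flow and the graph_for inspection are assumptions based on the usual IPEX bf16 inference recipe, and details may differ by version):

import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex

class LinearTanh(nn.Module):
    # Roughly mirrors the LinearTanh module added in tests/cpu/test_jit.py.
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        return torch.tanh(self.linear(x))

model = LinearTanh(3, 32).eval()
x = torch.rand(32, 3)

# bf16 path: ipex.optimize prepacks the linear weight, and the JIT fusion pass
# rewrites linear_run + aten::tanh into ipex_prepack::linear_tanh_run.
model = ipex.optimize(model, dtype=torch.bfloat16)
with torch.no_grad(), torch.cpu.amp.autocast():
    traced = torch.jit.trace(model, x)
    traced = torch.jit.freeze(traced)
    traced(x)  # warm-up so the profiling executor applies the optimization passes
    graph = traced.graph_for(x)

# The fused op should show up in the optimized graph on the bf16 path.
print(any(n.kind() == "ipex_prepack::linear_tanh_run" for n in graph.nodes()))

Note that the fp32 test below still expects plain aten::linear in the graph, so the fused kernel is only asserted on the bf16 path.
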
11 changes: 11 additions & 0 deletions intel_extension_for_pytorch/csrc/cpu/ideep/ideep/attributes.hpp
@@ -95,6 +95,17 @@ struct attr_t : public dnnl::primitive_attr {
    return attr;
  }

  static attr_t fuse_tanh(
      float scale = 1.0,
      float alpha = 0.f,
      float beta = 0.f) {
    attr_t attr;
    post_ops po;
    po.append_eltwise(scale, algorithm::eltwise_tanh, alpha, beta);
    attr.set_post_ops(po);
    return attr;
  }

  static attr_t fuse_elu(
      float scale = 1.0,
      float alpha = 0.f,
@@ -57,6 +57,15 @@ at::Tensor linear_gelu_run(
  return op_context->run(input, ideep::attr_t::fuse_gelu());
}

at::Tensor linear_tanh_run(
    const at::Tensor& input,
    const c10::intrusive_ptr<LinearOpContext>& op_context) {
  IPEX_RECORD_FUNCTION(
      "ipex_prepack::linear_tanh_run", std::vector<c10::IValue>({}));

  return op_context->run(input, ideep::attr_t::fuse_tanh());
}

at::Tensor linear_sigmoid_run(
    const at::Tensor& input,
    const c10::intrusive_ptr<LinearOpContext>& op_context) {
@@ -29,6 +29,10 @@ at::Tensor linear_gelu_run(
    const at::Tensor& input,
    const c10::intrusive_ptr<LinearOpContext>& op_context);

at::Tensor linear_tanh_run(
    const at::Tensor& input,
    const c10::intrusive_ptr<LinearOpContext>& op_context);

at::Tensor linear_sigmoid_run(
    const at::Tensor& input,
    const c10::intrusive_ptr<LinearOpContext>& op_context);
@@ -101,11 +101,12 @@ void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {

void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter_relu, rewriter_gelu, rewriter_silu,
      rewriter_sigmoid, rewriter_swish;
      rewriter_sigmoid, rewriter_swish, rewriter_tanh;
  std::array<std::string, 2> relu_operators = {"relu", "relu_"};
  std::array<std::string, 2> sigmoid_operators = {"sigmoid", "sigmoid_"};
  std::array<std::string, 2> silu_operators = {"silu", "silu_"};
  std::array<std::string, 2> mul_operators = {"mul", "mul_"};
  std::array<std::string, 2> tanh_operators = {"tanh", "tanh_"};

  auto linear_relu_rstring = CodeTemplate(R"(
    graph(%input, %weight, %bias, %out_features:int, %in_features:int, %batch_size:int, %weight_is_prepacked:bool):
@@ -120,6 +121,19 @@ void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph) {
        %res = ipex_prepack::linear_relu_run(%input, %packed_weight)
        return (%res))";

  auto linear_tanh_rstring = CodeTemplate(R"(
    graph(%input, %weight, %bias, %out_features:int, %in_features:int, %batch_size:int, %weight_is_prepacked:bool):
        %packed_weight = ipex_prepack::linear_prepack(%weight, %bias, %out_features, %in_features, %batch_size, %weight_is_prepacked)
        %x = ipex_prepack::linear_run(%input, %packed_weight)
        %res = aten::${tanh}(%x)
        return (%res))");

  std::string linear_tanh_fused = R"(
    graph(%input, %weight, %bias, %out_features:int, %in_features:int, %batch_size:int, %weight_is_prepacked:bool):
        %packed_weight = ipex_prepack::linear_prepack(%weight, %bias, %out_features, %in_features, %batch_size, %weight_is_prepacked)
        %res = ipex_prepack::linear_tanh_run(%input, %packed_weight)
        return (%res))";

  std::string linear_gelu = R"(
    graph(%input, %weight, %bias, %out_features:int, %in_features:int, %batch_size:int, %weight_is_prepacked:bool):
        %packed_weight = ipex_prepack::linear_prepack(%weight, %bias, %out_features, %in_features, %batch_size, %weight_is_prepacked)
@@ -174,6 +188,13 @@ void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph) {
        linear_relu_rstring.format(env), linear_relu_fused);
  }

  for (const auto& tanh : tanh_operators) {
    TemplateEnv env;
    env.s("tanh", tanh);
    rewriter_tanh.RegisterRewritePattern(
        linear_tanh_rstring.format(env), linear_tanh_fused);
  }

  for (const auto& silu : silu_operators) {
    TemplateEnv env;
    env.s("silu", silu);
@@ -198,6 +219,7 @@ void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph) {
  rewriter_gelu.RegisterRewritePattern(linear_gelu, linear_gelu_fused);

  rewriter_relu.runOnGraph(graph);
  rewriter_tanh.runOnGraph(graph);
  rewriter_gelu.runOnGraph(graph);
}

@@ -411,6 +411,24 @@ RegisterOperators op({
          };
        },
        aliasAnalysisFromSchema()),

    Operator(
        "ipex_prepack::linear_tanh_run(Tensor input, "
        "__torch__.torch.classes.ipex_prepack.LinearOpContext W_prepack) "
        "-> Tensor",
        [](const Node* node) -> Operation {
          return [](Stack* stack) {
            auto result = linear_tanh_run(
                (std::move(peek(stack, 0, 2))).toTensor(),
                (std::move(peek(stack, 1, 2)))
                    .toCustomClass<LinearOpContext>());
            drop(stack, 2);
            pack(stack, std::move(result));
            return 0;
          };
        },
        aliasAnalysisFromSchema()),

    Operator(
        "ipex_prepack::linear_sigmoid_run(Tensor input, "
        "__torch__.torch.classes.ipex_prepack.LinearOpContext W_prepack) "
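With the operator registered above, the new schema becomes visible to the TorchScript runtime as soon as the extension is loaded. A quick, hedged way to confirm the registration from Python (torch._C._jit_get_all_schemas is an internal PyTorch API, used here only as a debugging aid; this snippet is not part of the commit):

import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (importing registers the ipex_prepack JIT operators)

# List every registered schema mentioning the new fused op.
schemas = [str(s) for s in torch._C._jit_get_all_schemas()]
print([s for s in schemas if "ipex_prepack::linear_tanh_run" in s])
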
32 changes: 32 additions & 0 deletions tests/cpu/test_jit.py
@@ -366,6 +366,16 @@ def __init__(self, in_channels, out_channels, **kwargs):
    def forward(self, x):
        return F.gelu(self.linear(x))

class LinearTanh(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(LinearTanh, self).__init__()
        seed = 2018
        torch.manual_seed(seed)
        self.linear = nn.Linear(in_channels, out_channels, **kwargs)

    def forward(self, x):
        return F.tanh(self.linear(x))

class LinearSigmoid(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(LinearSigmoid, self).__init__()
@@ -2334,6 +2344,28 @@ def test_output_linear_gelu(self):
            kind_not_in_graph="ipex_prepack::linear_prepack",
            prec=5e-3)

    def test_output_linear_tanh(self):
        self._test_output(
            LinearTanh(3, 32, bias=True),
            torch.rand(32, 3),
            kind_in_graph="aten::linear")
        self._test_output_bf16(
            LinearTanh(3, 32, bias=True),
            torch.rand(32, 3),
            kind_in_graph="ipex_prepack::linear_tanh_run",
            kind_not_in_graph="ipex_prepack::linear_prepack",
            prec=5e-3)
        self._test_output(
            LinearTanh(3, 32, bias=False),
            torch.rand(32, 3),
            kind_in_graph="aten::linear")
        self._test_output_bf16(
            LinearTanh(3, 32, bias=False),
            torch.rand(32, 3),
            kind_in_graph="ipex_prepack::linear_tanh_run",
            kind_not_in_graph="ipex_prepack::linear_prepack",
            prec=5e-3)

    def test_output_linear_swish(self):
        def _test_onednn_fp32(model, input, kind_in_graph, prec=5e-3):
            model = model.eval()
