Wenzhe/bmm add (#407)
* add bmm_add fusion and test

* allow get_params to take binary kind
refactor fuse_binary

* add filter for bmm_add
wenzhe-nrv authored Mar 15, 2022
1 parent b5f7770 commit d1379aa
Showing 8 changed files with 124 additions and 0 deletions.
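Taken together, the change teaches the JIT to collapse an aten::bmm followed by an aten::add into a single ipex::bmm_add op backed by a oneDNN matmul with a binary_add post-op. As a standalone illustration (libtorch; not part of the commit), this is the eager pattern being matched and the reference result it must reproduce:

#include <torch/torch.h>

int main() {
  // shapes mirror the new test_bmm_add below
  auto input  = torch::randn({10, 3, 5});
  auto batch1 = torch::randn({10, 3, 4});
  auto batch2 = torch::randn({10, 4, 5});

  // the subgraph fuseBmmAdd matches: batched matmul, then element-wise add
  auto eager = torch::add(torch::bmm(batch1, batch2), input);

  // the reference the test checks against (also the kernel's fallback path)
  auto reference = torch::baddbmm(input, batch1, batch2);

  TORCH_CHECK(torch::allclose(eager, reference));
  return 0;
}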
16 changes: 16 additions & 0 deletions intel_extension_for_pytorch/csrc/cpu/ideep/ideep/attributes.hpp
@@ -65,6 +65,14 @@ struct attr_t : public dnnl::primitive_attr {
return attr;
}

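// Build an attr carrying a single binary post-op (e.g. binary_add);
// src_desc describes the second input of that binary operation.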
static attr_t fuse_binary(algorithm alg, memory::desc src_desc) {
attr_t attr;
post_ops po;
po.append_binary(alg, src_desc);
attr.set_post_ops(po);
return attr;
}

static attr_t fuse_relu(
float scale = 1.0,
float alpha = 0.f,
@@ -162,6 +170,7 @@ struct attr_t : public dnnl::primitive_attr {

algorithm alg = algorithm::undef;
float scale = 1.0, alpha = 1.0, beta = 0.0;
memory::desc binary_src_desc;

auto akind = po.kind(index);
switch (akind) {
@@ -171,6 +180,9 @@
case kind::eltwise:
po.get_params_eltwise(index, scale, alg, alpha, beta);
break;
case kind::binary:
po.get_params_binary(index, alg, binary_src_desc);
break;
default:
error::wrap_c_api(dnnl_invalid_arguments, "could not get params");
break;
@@ -243,6 +255,10 @@ struct attr_t : public dnnl::primitive_attr {
utils::to_bytes(bytes, beta);
bytes.append(1, '.');
utils::to_bytes(bytes, alg);
break;
case kind::binary:
utils::to_bytes(bytes, akind);
bytes.append(1, '.');
utils::to_bytes(bytes, alg);
break;
default:
break;
}
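fuse_binary and the new kind::binary branches of get_params and to_bytes are thin wrappers over oneDNN's post-op API. A minimal round-trip sketch against raw dnnl (assuming oneDNN 2.x headers; the helper name is ours, not part of the commit):

#include <dnnl.hpp>

// Build a primitive_attr with a binary_add post-op and query it back,
// mirroring attr_t::fuse_binary and the new kind::binary branch above.
dnnl::primitive_attr make_binary_add_attr(const dnnl::memory::desc& src1_desc) {
  dnnl::post_ops po;
  po.append_binary(dnnl::algorithm::binary_add, src1_desc);

  dnnl::primitive_attr attr;
  attr.set_post_ops(po);

  // round trip: recover the algorithm and source descriptor at index 0
  dnnl::algorithm alg;
  dnnl::memory::desc queried;
  attr.get_post_ops().get_params_binary(0, alg, queried);

  return attr;
}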
23 changes: 23 additions & 0 deletions intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.cpp
@@ -121,5 +121,28 @@ at::Tensor dil_matmul_div(
}
}

at::Tensor dil_bmm_add(
    const at::Tensor& input,
    const at::Tensor& batch1,
    const at::Tensor& batch2,
    const c10::Scalar& alpha) {
#if defined(IPEX_PROFILE_OP)
  RECORD_FUNCTION("dil_bmm_add", std::vector<c10::IValue>({}));
#endif
  auto batch1_dim = batch1.dim();
  auto batch2_dim = batch2.dim();
  // The binary_add post-op cannot scale its operand, so the fused path
  // only covers alpha == 1; every other case falls back to baddbmm.
  if (batch1_dim == batch2_dim && batch1_dim >= 3 &&
      alpha.to<float>() == 1.0f) {
    auto _input = input.is_contiguous() ? input : input.contiguous();
    ideep::tensor onednn_input = itensor_view_from_dense(_input);

    auto op_attr = ideep::attr_t::fuse_binary(
        dnnl::algorithm::binary_add, onednn_input.get_desc());
    return bmm_impl(
        batch1, batch2, at::Tensor(), op_attr, {onednn_input}, 1.0f);
  } else {
    // aten::add(bmm, input, alpha) == batch1 @ batch2 + alpha * input,
    // which is baddbmm with beta = alpha and alpha = 1.
    return at::baddbmm(input, batch1, batch2, /*beta=*/alpha, /*alpha=*/1);
  }
}

} // namespace cpu
} // namespace torch_ipex
7 changes: 7 additions & 0 deletions intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.h
@@ -15,6 +15,7 @@ namespace jit {
// So we fake some op namespaces to work around that.
namespace ipex {
static auto matmul_div = Symbol::fromQualString("ipex::matmul_div");
static auto bmm_add = Symbol::fromQualString("ipex::bmm_add");

} // namespace ipex

@@ -36,5 +37,11 @@ at::Tensor dil_matmul_div(
at::Tensor out_opt,
const c10::Scalar& div_input);

at::Tensor dil_bmm_add(
const at::Tensor& input,
const at::Tensor& batch1,
const at::Tensor& batch2,
const c10::Scalar& alpha);

} // namespace cpu
} // namespace torch_ipex
36 changes: 36 additions & 0 deletions intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp
@@ -432,6 +432,42 @@ void replaceInteractionWithQInteraction(std::shared_ptr<Graph>& graph) {
}
}

void fuseBmmAdd(std::shared_ptr<Graph>& graph) {
  auto bmm_add_rstring_v1 = R"(
    graph(%input, %batch1, %batch2, %alpha):
        %x = aten::bmm(%batch1, %batch2)
        %res = aten::add(%x, %input, %alpha)
        return (%res))";

  std::string bmm_add_fused = R"(
    graph(%input, %batch1, %batch2, %alpha):
        %res = ipex::bmm_add(%input, %batch1, %batch2, %alpha)
        return (%res))";

  // filter out the unsupported cases
  auto fusion_filter = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;

    // resolve the matched bmm operands through the pattern-value map
    auto batch1 =
        match_vmap.at(vmap.at("batch1"))->type()->cast<TensorType>();
    auto batch2 =
        match_vmap.at(vmap.at("batch2"))->type()->cast<TensorType>();
    // both operands must be tensors of the same known rank, at least 3-D,
    // mirroring the runtime check in dil_bmm_add
    if (!batch1 || !batch2 || !batch1->dim().has_value() ||
        !batch2->dim().has_value()) {
      return false;
    }
    if (batch1->dim().value() != batch2->dim().value()) {
      return false;
    }
    if (batch1->dim().value() < 3) {
      return false;
    }

    return true;
  };

  SubgraphRewriter rewriter_add_v1;
  rewriter_add_v1.RegisterRewritePattern(bmm_add_rstring_v1, bmm_add_fused);
  rewriter_add_v1.runOnGraph(graph, fusion_filter);
}

void FuseConcatBnRelu(std::shared_ptr<Graph>& graph) {
std::string aten_concat_bn_relu = R"(
graph(%input : Tensor[], %dim:int, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled):
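For reference, the two maps used by the filter compose as follows: vmap takes a pattern variable name (such as "batch1") to the pattern graph's Value*, and match.values_map takes that to the Value* in the graph being rewritten. A standalone sketch of the lookup (the helper name is ours, not part of the commit):

#include <string>
#include <unordered_map>
#include <torch/csrc/jit/ir/subgraph_matcher.h>

// Resolve a pattern variable name to the TensorType of the matched value
// in the real graph; returns nullptr when the value is not a tensor.
c10::TensorTypePtr matched_tensor_type(
    const torch::jit::Match& match,
    const std::unordered_map<std::string, torch::jit::Value*>& vmap,
    const std::string& name) {
  torch::jit::Value* pattern_value = vmap.at(name);
  torch::jit::Value* graph_value = match.values_map.at(pattern_value);
  return graph_value->type()->cast<c10::TensorType>();
}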
intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h
@@ -27,6 +27,8 @@ void FuseShuffle(std::shared_ptr<Graph>& graph);
void FuseMHAScoreCalc(std::shared_ptr<Graph>& graph);
void FuseLinearSwishCustomized(std::shared_ptr<Graph>& graph);
void replaceAtenMaxPool2dWithIpexMaxPool2d(std::shared_ptr<Graph>& graph);
void fuseBmmAdd(std::shared_ptr<Graph>& graph);

void replaceOpsWithAtenInplaceOps(std::shared_ptr<Graph>& graph);
void replaceAtenOpsWithIpexInplaceOps(std::shared_ptr<Graph>& graph);
void replaceAtenSoftmaxWithIpexSoftmax(std::shared_ptr<Graph>& graph);
@@ -550,6 +550,23 @@ RegisterOperators op({
},
aliasAnalysisFromSchema()),

Operator(
"ipex::bmm_add(Tensor input, Tensor batch1, Tensor batch2, Scalar alpha) -> "
"Tensor",
[](const Node* node) -> Operation {
return [](Stack* stack) {
auto result = dil_bmm_add(
(std::move(peek(stack, 0, 4))).toTensor(),
(std::move(peek(stack, 1, 4))).toTensor(),
(std::move(peek(stack, 2, 4))).toTensor(),
(std::move(peek(stack, 3, 4))).toScalar());
drop(stack, 4);
pack(stack, std::move(result));
return 0;
};
},
aliasAnalysisFromSchema()),

Operator(
"ipex::mha_scores_calc(Tensor q, Tensor k, Tensor rel_qk, Scalar "
"alpha, "
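Registered ops follow the interpreter's stack calling convention: peek(stack, i, N) reads the i-th of the top N values without popping, drop pops the consumed inputs, and pack pushes the result. A hedged sketch of the same protocol for a hypothetical two-tensor op (names are ours, not part of the commit):

#include <ATen/ATen.h>
#include <ATen/core/stack.h>

// Hypothetical op body using the same stack protocol as ipex::bmm_add:
// read two tensor inputs, pop them, push a single tensor output.
int hypothetical_add_op(torch::jit::Stack* stack) {
  auto a = torch::jit::peek(stack, 0, 2).toTensor(); // first of top 2
  auto b = torch::jit::peek(stack, 1, 2).toTensor(); // second of top 2
  torch::jit::drop(stack, 2);                        // pop both inputs
  torch::jit::pack(stack, at::add(a, b));            // push the result
  return 0;
}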
3 changes: 3 additions & 0 deletions intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp
@@ -355,6 +355,9 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
// Multi-Head-Attention
graph_rewrite::FuseMHAScoreCalc(graph);

// Fuse aten::bmm + aten::add into ipex::bmm_add
graph_rewrite::fuseBmmAdd(graph);

// Replace _convolution with conv2d or conv3d
graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

20 changes: 20 additions & 0 deletions tests/cpu/test_jit.py
@@ -621,6 +621,15 @@ def forward(self, x):
else:
return mm_res.div(torch.ones(mm_res_shape,dtype=x.dtype)+1)

class BmmAdd(nn.Module):
    def __init__(self):
        super(BmmAdd, self).__init__()

    def forward(self, input, batch1, batch2):
        bmm_res = torch.bmm(batch1, batch2)
        res = torch.add(bmm_res, input)
        return res

class MHAScoresCalculation(nn.Module):
def __init__(self, dim_per_head, softmax_dim=-1):
super(MHAScoresCalculation, self).__init__()
@@ -2476,6 +2485,17 @@ def test_matmul_div(self):
kind_not_in_graph=None,
prec=5e-3)

def test_bmm_add(self):
    M = torch.randn(10, 3, 5)
    batch1 = torch.randn(10, 3, 4)
    batch2 = torch.randn(10, 4, 5)
    mod = BmmAdd()
    traced_mod = torch.jit.trace(mod, (M, batch1, batch2))
    # specialized graph; after the fusion pass it should contain
    # ipex::bmm_add in place of the aten::bmm + aten::add pair
    fused_mod = traced_mod.graph_for(M, batch1, batch2)
    out = traced_mod(M, batch1, batch2)
    expected = torch.baddbmm(M, batch1, batch2)
    self.assertTrue(torch.allclose(out, expected))

def test_ipex_softmax(self):
self._test_output(
AtenSoftmaxRepalce(),
