enable bmm (apache#12018)
billishyahao authored and masahi committed Jul 15, 2022
1 parent 19b6a23 commit ca14468
Showing 3 changed files with 81 additions and 2 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relay/op/contrib/dnnl.py
@@ -105,6 +105,7 @@ def _func_wrapper(expr):
_register_external_op_helper("add")
_register_external_op_helper("multiply")
_register_external_op_helper("nn.layer_norm")
_register_external_op_helper("nn.batch_matmul")


def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -563,6 +564,7 @@ def visit_call(self, call):
"nn.conv3d_transpose",
"nn.dense",
"nn.layer_norm",
"nn.batch_matmul",
]
)
if isinstance(call.op, tvm.tir.op.Op):
@@ -679,7 +681,7 @@ def __init__(self):
const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
p1 = is_op("power")(cdiff, const_two)
mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
eps = is_expr(relay.const(1e-5))
eps = is_expr(relay.const(1e-5)) | is_expr(relay.const(1e-6))
added_eps = is_op("add")(mp1, eps)
deno = is_op("sqrt")(added_eps)
div_out = is_op("divide")(diff, deno)
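Registering `nn.batch_matmul` above makes the op eligible for DNNL offload during graph partitioning. A minimal sketch of how this is typically exercised, assuming the module's `partition_for_dnnl` helper and a oneDNN-enabled TVM build (the shapes here are illustrative, matching the test defaults below):

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl

# Small graph with nn.batch_matmul (transpose_b=True is the Relay default).
x = relay.var("x", shape=(1, 16, 8), dtype="float32")
y = relay.var("y", shape=(1, 4, 8), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.batch_matmul(x, y))

# After partitioning, supported ops (now including nn.batch_matmul) end up in
# functions annotated with Compiler="dnnl" and executed by the DNNL JSON runtime.
mod = partition_for_dnnl(mod)
print(mod)
```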
50 changes: 49 additions & 1 deletion src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -269,6 +269,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
Binary(nid, dnnl::algorithm::binary_mul);
} else if ("nn.layer_norm" == op_name) {
LayerNorm(nid);
} else if ("nn.batch_matmul" == op_name) {
BatchMatMul(nid);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
@@ -483,6 +485,52 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{sum_in_tr, DNNL_ARG_DST});
}

void BatchMatMul(const size_t& nid) {
auto node = nodes_[nid];

// Setup attributes.
auto src_tr = GetInput(nid, 0);
auto wgh_tr = GetInput(nid, 1);
auto dst_tr = GetOutput(nid, 0);
auto bias_tr = TensorRequisite{};

auto attr = ParseAttrs(nid, &bias_tr);
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

bool transpose_a = GetNodeAttr<bool>(node, "transpose_a");
bool transpose_b = GetNodeAttr<bool>(node, "transpose_b");

if (transpose_a) {
src_tr = src_tr.Permute({0, 2, 1});
}
if (transpose_b) {
wgh_tr = wgh_tr.Permute({0, 2, 1});
}

// Assume the bias (if present) is valid and can be squeezed to 1D.
bias_tr = bias_tr.Reshape({dst_tr.dims()[1]});

// Matmul description.
auto bmm_desc = dnnl::matmul::desc(src_tr.LayoutAny().desc(), wgh_tr.LayoutAny().desc(),
bias_tr.LayoutAny().desc(), dst_tr.LayoutAny().desc());

// Create the primitive descriptor; attr carries any fused elementwise post-ops.
auto bmm_prim_desc = dnnl::matmul::primitive_desc(bmm_desc, attr, engine_);

src_tr = src_tr.RequestLayout(bmm_prim_desc.src_desc());
wgh_tr = wgh_tr.RequestLayout(bmm_prim_desc.weights_desc());
dst_tr = dst_tr.RequestLayout(bmm_prim_desc.dst_desc());
bias_tr = bias_tr.RequestLayout(bmm_prim_desc.bias_desc());

auto scratchpad_tr = TensorRequisite::AsIs(bmm_prim_desc.scratchpad_desc());

Submit(dnnl::matmul(bmm_prim_desc), {{DNNL_ARG_SRC, src_tr},
{DNNL_ARG_WEIGHTS, wgh_tr},
{DNNL_ARG_BIAS, bias_tr},
{DNNL_ARG_SCRATCHPAD, scratchpad_tr},
{DNNL_ARG_DST, dst_tr}});
}

void BatchNorm(const size_t& nid) {
auto node = nodes_[nid];

@@ -755,7 +803,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

TensorRequisite GetOutput(const size_t& nid, const int idx) {
if (idx == -1) return {}; // -1 reserved value for empty input.

const JSONGraphNode& node = nodes_[nid];

ICHECK_LT(idx, node.GetNumOutput());
@@ -764,6 +811,7 @@
auto eid = node_row_ptr_[nid] + static_cast<uint32_t>(idx);

ICHECK(data_entry_[eid] == nullptr);

auto desc = MakePlainDesc(shape, dtype);

return TensorRequisite::AsIs(desc, eid).Backward();
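In the runtime code above, `transpose_a`/`transpose_b` are folded into the memory descriptors via `Permute({0, 2, 1})` rather than lowered to a separate transpose primitive. A hedged NumPy sketch of the semantics the DNNL path has to reproduce (the helper name below is illustrative, not from the commit):

```python
import numpy as np

def batch_matmul_ref(x, y, transpose_a=False, transpose_b=True):
    """Reference semantics of relay.nn.batch_matmul as handled by BatchMatMul().

    Swapping the last two axes mirrors the Permute({0, 2, 1}) calls in the DNNL
    runtime, which fold the transpose flags into the tensor layouts.
    """
    if transpose_a:
        x = x.transpose(0, 2, 1)
    if transpose_b:
        y = y.transpose(0, 2, 1)
    return np.matmul(x, y)

x = np.random.rand(1, 2, 4).astype("float32")
k = np.random.rand(1, 3, 4).astype("float32")
assert batch_matmul_ref(x, k).shape == (1, 2, 3)
```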
29 changes: 29 additions & 0 deletions tests/python/contrib/test_dnnl.py
@@ -556,6 +556,35 @@ def get_dense(
return out, dic, param_lst


def get_bmm(
x_shape=(1, 16, 8), k_shape=(1, 4, 8), dtype="float32", transpose_a=False, transpose_b=True
):
x = relay.var("x", shape=(x_shape), dtype=dtype)
kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
out = relay.nn.batch_matmul(
x, kernel, out_dtype=dtype, transpose_a=transpose_a, transpose_b=transpose_b
)
dic = {"x": x_shape, "kernel": k_shape}
param_lst = ["kernel"]
return out, dic, param_lst


def test_bmm(run_module, dtype="float32"):
x_shape = (1, 2, 4)
k_shape = (1, 3, 4)

dense, dic, param_lst = get_bmm(x_shape, k_shape, dtype=dtype)
dense = tvm.IRModule.from_expr(dense)
config = dense, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)

k_shape_t = (1, 4, 3)
dense, dic, param_lst = get_bmm(x_shape, k_shape_t, dtype=dtype, transpose_b=False)
dense = tvm.IRModule.from_expr(dense)
config = dense, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)


def get_dense_bias(
x_shape=(1, 16),
k_shape=(32, 16),
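The new test follows the same pattern as the existing dense/conv tests: build the Relay expression with `get_bmm`, wrap it in an IRModule, and hand it to `run_and_verify_func`, which compiles with and without the DNNL path and compares outputs. A hedged sketch of the reference check this amounts to, using a plain LLVM build and a NumPy baseline (the exact harness options may differ):

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

# Same expression get_bmm() builds with its default shapes.
x = relay.var("x", shape=(1, 16, 8), dtype="float32")
kernel = relay.var("kernel", shape=(1, 4, 8), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.batch_matmul(x, kernel))

# Plain LLVM build used as the reference the DNNL-offloaded build must match.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")

dev = tvm.cpu(0)
rt = graph_executor.GraphModule(lib["default"](dev))
x_np = np.random.rand(1, 16, 8).astype("float32")
k_np = np.random.rand(1, 4, 8).astype("float32")
rt.set_input("x", x_np)
rt.set_input("kernel", k_np)
rt.run()

# nn.batch_matmul with transpose_b=True computes x @ k^T per batch.
ref = np.matmul(x_np, k_np.transpose(0, 2, 1))
np.testing.assert_allclose(rt.get_output(0).numpy(), ref, rtol=1e-5)
```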
