Commit: Add blas support for matmul

jcf94 committed Jun 28, 2021
1 parent 49ffbd1 commit 5957bb3
Showing 2 changed files with 98 additions and 24 deletions.
27 changes: 27 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
@@ -389,6 +389,33 @@ def matmul_strategy_cpu(attrs, inputs, out_type, target):
name="matmul.generic",
)

same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
dtype = inputs[0].dtype
u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32"
if "cblas" in target.libs:
with SpecializedCondition(same_type and dtype in ["float32", "float64"]):
strategy.add_implementation(
wrap_compute_matmul(topi.x86.matmul_cblas),
wrap_topi_schedule(topi.x86.schedule_matmul_cblas),
name="matmul_cblas.x86",
plevel=13,
)
if "mkl" in target.libs:
with SpecializedCondition(same_type and dtype in ["float32", "float64"] or u8s8s32):
strategy.add_implementation(
wrap_compute_matmul(topi.x86.matmul_mkl),
wrap_topi_schedule(topi.x86.schedule_matmul_mkl),
name="matmul_mkl.x86",
plevel=14,
)
if "mkldnn" in target.libs:
with SpecializedCondition(same_type and dtype == "float32"):
strategy.add_implementation(
wrap_compute_matmul(topi.x86.matmul_mkldnn),
wrap_topi_schedule(topi.x86.schedule_matmul_mkldnn),
name="matmul_mkldnn.x86",
plevel=15,
)
return strategy
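
For context, a rough relay-level sketch of how these registrations get picked up (assuming a TVM build with CBLAS enabled and the relay nn.matmul op; illustrative only, not part of this commit):

import tvm
from tvm import relay

a = relay.var("a", shape=(16, 32), dtype="float32")
b = relay.var("b", shape=(32, 64), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([a, b], relay.nn.matmul(a, b)))

# With "-libs=cblas" in the target, matmul_cblas.x86 (plevel=13) outranks the
# generic matmul implementation; mkl (plevel=14) and mkldnn (plevel=15) take
# priority when those libs are listed instead.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -libs=cblas")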


95 changes: 71 additions & 24 deletions python/tvm/topi/x86/dense.py
@@ -281,72 +281,119 @@ def _callback(op):
return s


def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
"""Compute dense using a BLAS library"""
def matmul_blas_common(cfg, data, weight, bias, out_dtype, data_transposed, weight_transposed, lib):
"""Compute matmul/dense using a BLAS library"""
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
cfg.add_flop(M * K * N * 2)
if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype == "int32":
if not hasattr(lib, "matmul_u8s8s32"):
raise NotImplementedError(
f"Dense with {lib.__name__} for {data.dtype} is not supported "
f"Matmul/Dense with {lib.__name__} for {data.dtype} is not supported "
"(matmulu8s8s32 not imlemented)"
)
C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
C = lib.matmul_u8s8s32(data, weight, data_transposed, weight_transposed, dtype=out_dtype)
elif data.dtype == "float32" or data.dtype == "float64":
C = lib.matmul(data, weight, False, True)
C = lib.matmul(data, weight, data_transposed, weight_transposed)
else:
raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype} is not supported")
raise NotImplementedError(f"Matmul/Dense with {lib.__name__} for {data.dtype} is not supported")

if bias is not None:
C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST)
return C
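
The two transposition flags follow the usual BLAS convention. A minimal NumPy reference (an illustration only, not TVM code) of what matmul_blas_common asks the library to compute, and how the dense layout maps onto weight_transposed=True:

import numpy as np

def matmul_ref(a, b, a_transposed=False, b_transposed=False):
    # out = op(a) @ op(b), where op() transposes when the flag is set
    a = a.T if a_transposed else a
    b = b.T if b_transposed else b
    return a @ b

data = np.random.rand(4, 8).astype("float32")     # (M, K)
weight = np.random.rand(16, 8).astype("float32")  # dense stores weight as (N, K)
out = matmul_ref(data, weight, False, True)       # (M, N), i.e. data @ weight.T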


def schedule_matmul_blas_common(outs):
"""Default matmul schedule for BLAS library"""
s = te.create_schedule([x.op for x in outs])
te.schedule.AutoInlineInjective(s)

for out in outs:
if "dense" not in out.op.tag and "matmul" not in out.op.tag:
schedule_injective_from_existing(s, out)
return s


@autotvm.register_topi_compute("dense_cblas.x86")
def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using a cblas"""
return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas)
return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, cblas)


@autotvm.register_topi_schedule("dense_cblas.x86")
def schedule_dense_cblas(_, outs):
"""Create schedule for dense_cblas"""
return generic.schedule_extern(outs)
return schedule_matmul_blas_common(outs)


@autotvm.register_topi_compute("dense_mkl.x86")
def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using mkl"""
return dense_blas_common(cfg, data, weight, bias, out_dtype, mkl)
return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkl)


@autotvm.register_topi_schedule("dense_mkl.x86")
def schedule_dense_mkl(_, outs):
"""Create schedule for dense_mkl"""
# return generic.schedule_extern(outs)
s = te.create_schedule([x.op for x in outs])
te.schedule.AutoInlineInjective(s)

def _callback(op):
if "broadcast" in op.tag or "injective" in op.tag or "elemwise" in op.tag:
schedule_injective_from_existing(s, op.output(0))

# traverse_inline(s, outs[0].op, _callback)
for out in outs:
if "dense" not in out.op.name:
schedule_injective_from_existing(s, out)
return s
return schedule_matmul_blas_common(outs)


@autotvm.register_topi_compute("dense_mkldnn.x86")
def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using mkldnn"""
return dense_blas_common(cfg, data, weight, bias, out_dtype, mkldnn)
return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkldnn)


@autotvm.register_topi_schedule("dense_mkldnn.x86")
def schedule_dense_mkldnn(_, outs):
"""Create schedule for dense_mkldnn"""
return generic.schedule_extern(outs)
return schedule_matmul_blas_common(outs)


@autotvm.register_topi_compute("matmul_cblas.x86")
def matmul_cblas(
cfg, data, weight, bias=None, out_dtype=None, data_transposed=False, weight_transposed=False
):
"""Compute matmul using a cblas"""
return matmul_blas_common(
cfg, data, weight, bias, out_dtype, data_transposed, weight_transposed, cblas
)


@autotvm.register_topi_schedule("matmul_cblas.x86")
def schedule_matmul_cblas(_, outs):
"""Create schedule for matmul_cblas"""
return schedule_matmul_blas_common(outs)


@autotvm.register_topi_compute("matmul_mkl.x86")
def matmul_mkl(
cfg, data, weight, bias=None, out_dtype=None, data_transposed=False, weight_transposed=False
):
"""Compute matmul using mkl"""
return matmul_blas_common(
cfg, data, weight, bias, out_dtype, data_transposed, weight_transposed, mkl
)


@autotvm.register_topi_schedule("matmul_mkl.x86")
def schedule_matmul_mkl(_, outs):
"""Create schedule for matmul_mkl"""
return schedule_matmul_blas_common(outs)


@autotvm.register_topi_compute("matmul_mkldnn.x86")
def matmul_mkldnn(
cfg, data, weight, bias=None, out_dtype=None, data_transposed=False, weight_transposed=False
):
"""Compute matmul using mkldnn"""
return matmul_blas_common(
cfg, data, weight, bias, out_dtype, data_transposed, weight_transposed, mkldnn
)


@autotvm.register_topi_schedule("matmul_mkldnn.x86")
def schedule_matmul_mkldnn(_, outs):
"""Create schedule for matmul_mkldnn"""
return schedule_matmul_blas_common(outs)
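
A rough TOPI-level sketch of exercising the new entry points directly (assumes a TVM build with a CBLAS provider enabled via USE_BLAS; the function names come from this diff, the rest is an untested illustration):

import tvm
from tvm import te, topi

A = te.placeholder((64, 128), name="A", dtype="float32")   # (M, K)
B = te.placeholder((256, 128), name="B", dtype="float32")  # (N, K)

with tvm.target.Target("llvm -libs=cblas"):
    # weight_transposed=True reproduces the dense layout: out = A @ B.T
    C = topi.x86.matmul_cblas(A, B, None, "float32", False, True)
    s = topi.x86.schedule_matmul_cblas([C])

f = tvm.build(s, [A, B, C], target="llvm -libs=cblas")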
