
[Matmul] Add matmul op #8234

Merged 29 commits on Jun 30, 2021
26 changes: 26 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -961,6 +961,32 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
}
};

/*! \brief Attributes for matmul operator */
struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
IndexExpr units;
DataType out_dtype;
bool transpose_a;
bool transpose_b;
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite

TVM_DECLARE_ATTRS(MatmulAttrs, "relay.attrs.MatmulAttrs") {
TVM_ATTR_FIELD(units).describe("Number of hidden units of the matmul transformation.");

// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");

TVM_ATTR_FIELD(transpose_a)
.set_default(false)
.describe("Whether the first input tensor is in transposed format.");

TVM_ATTR_FIELD(transpose_b)
.set_default(false)
.describe("Whether the second input tensor is in transposed format.");
}
};

/*! \brief Attributes for dense operator */
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
IndexExpr units;
19 changes: 18 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
@@ -44,6 +44,16 @@

__all__ = ["from_tensorflow"]

# The default configurations of Relay TensorFlow frontend.
TF_DEFAULT_CONFIGS = {
# By default, TVM converts `tf.matmul` to `transpose(weight) + nn.dense`, which introduces
# unnecessary overhead in the weight transpose. Change this flag to False to convert directly
# to `nn.matmul` and avoid that overhead.
# However, please note that `nn.matmul` is experimental, so it may have performance issues.
"use_dense": True,
}

# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -1204,7 +1214,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
return func, self._params


def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op=True):
Contributor: I don't think we should have a flag here. We should just commit to one codepath.

Contributor (author): The problem is that we're not able to remove all of the nn.dense ops at this moment, and there aren't enough AutoTVM templates for nn.matmul. So the use of nn.matmul can only be seen as an experimental feature; we should not change the default behavior, since that may affect those who are already using nn.dense.

Contributor: Can't we use the dense schedules when transpose_a=False and transpose_b=True? Then we could convert all nn.dense to nn.matmul.

Contributor: This PR already uses the dense schedule for matmul_nt when lowering to TOPI. On the other hand, as @jcf94 mentioned in the PR comment, doing so would affect many more places in the codebase, and it is better to convert them gradually rather than in a single PR. That sounds reasonable to me.

"""Load tensorflow graph which is a python tensorflow graph object into relay.
The companion parameters will be handled automatically.

@@ -1222,6 +1232,11 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
outputs : List of output tensor names (Optional)
if not specified then the last node is assumed as graph output.

use_dense_op : bool (Optional) = True
True to convert `tf.matmul` to `nn.dense`, otherwise convert to `nn.matmul`.
The `nn.dense` op requires the data tensor to be non-transposed and the weight tensor to be
transposed, so extra `transpose` ops may be inserted into the original graph.

Returns
-------
mod : tvm.IRModule
@@ -1230,6 +1245,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
params : dict of str to tvm.nd.NDArray
Dict of converted parameters stored in tvm.nd.NDArray format
"""
global TF_DEFAULT_CONFIGS
TF_DEFAULT_CONFIGS["use_dense"] = use_dense_op

g = GraphProto()
mod, params = g.from_tensorflow(graph, layout, shape, outputs)
20 changes: 15 additions & 5 deletions python/tvm/relay/frontend/tensorflow_ops.py
@@ -1113,13 +1113,23 @@ def _impl(inputs, attr, params, mod):

def _matmul():
def _impl(inputs, attr, params, mod):
from .tensorflow import TF_DEFAULT_CONFIGS

channels = _infer_channels(inputs[1], not attr["transpose_b"])
if TF_DEFAULT_CONFIGS["use_dense"]:
if attr["transpose_a"]:
inputs[0] = _op.transpose(inputs[0], axes=(1, 0))
if not attr["transpose_b"]:
inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
return AttrCvt(
op_name="dense",
extras={"units": channels},
ignores=["transpose_a", "transpose_b", "T"],
)(inputs, attr)
return AttrCvt(
op_name="dense", extras={"units": channels}, ignores=["transpose_a", "transpose_b", "T"]
op_name="matmul",
extras={"units": channels},
ignores=["T"],
)(inputs, attr)

return _impl
29 changes: 29 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
@@ -554,6 +554,35 @@ def dense_grad(orig, grad):
]


@register_gradient("nn.matmul")
def matmul_grad(orig, grad):
"""Returns [grad' @ tensor_b, tensor_a @ grad']"""
tensor_a, tensor_b = orig.args
if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, True):
return [
collapse_sum_like(
_nn.matmul(tensor_b, grad, transpose_a=True, transpose_b=True), tensor_a
),
collapse_sum_like(
_nn.matmul(grad, tensor_a, transpose_a=True, transpose_b=True), tensor_b
),
]
if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (True, False):
return [
collapse_sum_like(_nn.matmul(tensor_b, grad, transpose_b=True), tensor_a),
collapse_sum_like(_nn.matmul(tensor_a, grad), tensor_b),
]
if (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, True):
# Keep using the dense op here to avoid introducing extra ops
# TODO(jcf94): Merge all to nn.matmul when it is finally ready
return dense_grad(orig, grad)
# (orig.attrs["transpose_a"], orig.attrs["transpose_b"]) == (False, False)
return [
collapse_sum_like(_nn.matmul(grad, tensor_b, transpose_b=True), tensor_a),
collapse_sum_like(_nn.matmul(tensor_a, grad, transpose_a=True), tensor_b),
]


@register_gradient("nn.batch_matmul")
def batch_matmul_grad(orig, grad):
"""gradient for nn.batch_matmul: in einsum LHS_bik,RHS_bjk->RES_bij
63 changes: 56 additions & 7 deletions python/tvm/relay/op/nn/_nn.py
@@ -52,6 +52,32 @@
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


@reg.register_legalize("nn.matmul")
def legalize_matmul(attrs, inputs, types):
"""Legalize matmul op.

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current matmul
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.matmul_legalize(attrs, inputs, types)


# matmul
reg.register_strategy("nn.matmul", strategy.matmul_strategy)
reg.register_pattern("nn.matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_legalize("nn.dense")
def legalize_dense(attrs, inputs, types):
"""Legalize dense op.
@@ -1160,21 +1186,44 @@ def batch_flatten_shape_func(attrs, inputs, _):


@script
def _matmul_shape_func(tensor_a_shape, tensor_b_shape, transpose_a, transpose_b):
out = output_tensor((tensor_a_shape.shape[0],), "int64")
for i in const_range(out.shape[0] - 1):
out[i] = tensor_a_shape[i]
if transpose_a:
out[out.shape[0] - 2] = tensor_a_shape[out.shape[0] - 1]
out[out.shape[0] - 1] = tensor_b_shape[0] if transpose_b else tensor_b_shape[1]

return out


@reg.register_shape_func("nn.matmul", False)
def matmul_shape_func(attrs, inputs, _):
"""Shape function for matmul op."""
ret = [
_matmul_shape_func(
inputs[0],
inputs[1],
expr.IntImm("bool", attrs.transpose_a),
expr.IntImm("bool", attrs.transpose_b),
)
]
return ret


@reg.register_shape_func("nn.dense", False)
def dense_shape_func(attrs, inputs, _):
"""Shape function for dense op. This is an alias of matmul_nt operator for data tensor in
non-transposed format and weight tensor in transposed format.
"""
Shape function for dense op.
"""
ret = [_dense_shape_func(inputs[0], inputs[1])]
ret = [
_matmul_shape_func(
inputs[0],
inputs[1],
expr.IntImm("bool", False),
expr.IntImm("bool", True),
)
]
return ret


44 changes: 44 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -1471,6 +1471,50 @@ def bias_add(data, bias, axis=1):
return _make.bias_add(data, bias, axis)


def matmul(tensor_a, tensor_b, units=None, out_dtype="", transpose_a=False, transpose_b=False):
"""Matmul operator.
Applies a linear transformation. Both A and B can be transposed.

.. math::

C = A * B

Parameters
----------
tensor_a : tvm.relay.Expr
The first input of the operator,
of shape `(d_1, d_2, ..., d_n, units_in)` or `(d_1, d_2, ..., units_in, d_n)`.
Contributor: Shouldn't both input shapes be 2-dimensional?

Contributor (author): No, the input of matmul is allowed to be a multi-dimensional tensor (not limited to 2-D). This is copied from the original nn.dense; other frameworks like PyTorch have a similar definition.

Contributor: Can you update the definition of the computation above to reflect these shapes, then?


tensor_b : tvm.relay.Expr
The second input expression, a 2-D matrix,
of shape `(units_in, units)` or `(units, units_in)`.

units : Optional[int]
Number of hidden units of the matmul transformation.

out_dtype : Optional[str]
Specifies the output data type for mixed precision matmul.

transpose_a : Optional[bool] = False
Whether tensor_a is in transposed format.

transpose_b : Optional[bool] = False
Whether tensor_b is in transposed format.

Returns
-------
result : tvm.relay.Expr
The computed result, of shape `(d_1, d_2, ..., d_n, units)`.
"""
# Since `nn.dense` currently has better TOPI schedule support, prefer `dense`
# over `matmul` for better compatibility
if not transpose_a and transpose_b:
# TODO(jcf94): Remove this when `nn.matmul` is finally ready
return dense(tensor_a, tensor_b, units, out_dtype)
return _make.matmul(tensor_a, tensor_b, units, out_dtype, transpose_a, transpose_b)


def dense(data, weight, units=None, out_dtype=""):
"""Dense operator.
Applies a linear transformation
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -64,6 +64,11 @@ class BiasAddAttrs(Attrs):
"""Atttribute of nn.bias_add"""


@tvm._ffi.register_object("relay.attrs.MatmulAttrs")
class MatmulAttrs(Attrs):
"""Attributes for nn.matmul"""


@tvm._ffi.register_object("relay.attrs.DenseAttrs")
class DenseAttrs(Attrs):
"""Attributes for nn.dense"""
32 changes: 32 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -698,6 +698,38 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@matmul_strategy.register(["cuda", "gpu"])
def matmul_strategy_cuda(attrs, inputs, out_type, target):
"""Matmul cuda strategy."""
strategy = _op.OpStrategy()

if is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
naive_schedule,
name="matmul.cuda",
)
else:
logger.warning(
"Matmul is not optimized for cuda. Recommend using cublas for better performance."
)
# Temporarily use this as a basic schedule
strategy.add_implementation(
wrap_compute_matmul(topi.cuda.matmul_default_cuda),
wrap_topi_schedule(topi.cuda.schedule_matmul_default_cuda),
name="matmul_default.cuda",
)

if target.kind.name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_matmul(topi.cuda.matmul_cublas),
wrap_topi_schedule(topi.cuda.schedule_matmul_cublas),
name="matmul_cublas.cuda",
plevel=25,
)
return strategy


@dense_strategy.register(["cuda", "gpu"])
def dense_strategy_cuda(attrs, inputs, out_type, target):
"""dense cuda strategy"""
36 changes: 36 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -712,6 +712,42 @@ def dilation2d_strategy(attrs, inputs, out_type, target):
return strategy


# matmul
def wrap_compute_matmul(topi_compute, need_auto_scheduler_layout=False):
"""wrap matmul topi compute"""

def _compute_matmul(attrs, inputs, out_type):
"""Compute definition of matmul"""
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
args = [
inputs[0],
inputs[1],
None,
out_dtype,
attrs.transpose_a,
attrs.transpose_b,
]
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
return [topi_compute(*args)]

return _compute_matmul


@override_native_generic_func("matmul_strategy")
def matmul_strategy(attrs, inputs, out_type, target):
"""matmul generic strategy"""
logger.warning("matmul is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
wrap_topi_schedule(topi.generic.schedule_matmul),
name="matmul.generic",
)
return strategy


# dense
def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
"""wrap dense topi compute"""