Implemented and tested full Trilu op.
jwfromm authored and Josh Fromm committed Jul 19, 2022
1 parent 2191179 commit 698e978
Showing 11 changed files with 199 additions and 27 deletions.
9 changes: 9 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -575,6 +575,15 @@ struct StftAttrs : public tvm::AttrsNode<StftAttrs> {
}
}; // struct StftAttrs

struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
bool upper;

TVM_DECLARE_ATTRS(TriluAttrs, "relay.attrs.TriluAttrs") {
TVM_DECLARE_ATTRS(TriluAttrs, "relay.attrs.TriluAttrs") {
TVM_ATTR_FIELD(upper).set_default(true).describe(
"Whether to keep the upper or lower triangular half of the matrix.");
}
}; // struct TriluAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -4685,6 +4685,20 @@ def _impl_v12(cls, inputs, attr, params):
return _op.einsum(inputs, equation)


class Trilu(OnnxOpConverter):
"""Operator converter for Trilu"""

@classmethod
def _impl_v14(cls, inputs, attr, params):
upper = attr.get("upper", True)
if len(inputs) == 2:
data, k = inputs
else:
data = inputs[0]
k = 0
return _op.trilu(data, k, upper)


class RandomNormal(OnnxOpConverter):
"""Operator converter for random_normal"""

@@ -5345,6 +5359,7 @@ def _get_convert_map(opset):
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
"Einsum": Einsum.get_converter(opset),
"Trilu": Trilu.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
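
The converter maps an ONNX Trilu node (opset 14) onto relay.trilu, defaulting k to 0 when the optional diagonal input is absent. A minimal sketch of exercising it end to end — not part of the diff above, and assuming onnx, numpy, and an LLVM-enabled TVM build are available:

import numpy as np
from onnx import TensorProto, helper
from tvm import relay

# One-node ONNX graph applying Trilu with upper=0 (keep the lower triangle).
node = helper.make_node("Trilu", inputs=["x"], outputs=["y"], upper=0)
graph = helper.make_graph(
    [node],
    "trilu_example",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4, 4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])

# Import through the Trilu converter and compare against numpy.
mod, _ = relay.frontend.from_onnx(model, shape={"x": (4, 4)})
x = np.random.rand(4, 4).astype("float32")
res = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(x).numpy()
np.testing.assert_allclose(res, np.tril(x))
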
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -191,6 +191,10 @@ def stft_shape_func(attrs, inputs, _):
]


# trilu
_reg.register_strategy("trilu", strategy.trilu_strategy)


# scatter_add
@_reg.register_compute("scatter_add")
def compute_scatter_add(attrs, inputs, output_type):
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -617,3 +617,8 @@ class NLLLossAttrs(Attrs):
@tvm._ffi.register_object("relay.attrs.FixedPointMultiplyAttrs")
class FixedPointMultiplyAttrs(Attrs):
"""Attributes used in fixed_point_multiply operators"""


@tvm._ffi.register_object("relay.attrs.TriluAttrs")
class TriluAttrs(Attrs):
"""Attributes used in trilu operators"""
28 changes: 28 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -1460,6 +1460,34 @@ def _compute_stft(attrs, inputs, output_type):
return _compute_stft


# trilu
@override_native_generic_func("trilu_strategy")
def trilu_strategy(attrs, outs, out_type, target):
"""trilu generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_trilu(topi.trilu),
wrap_topi_schedule(topi.generic.schedule_extern),
name="trilu.generic",
)
return strategy


def wrap_compute_trilu(topi_compute):
"""Wrap trilu compute"""

def _compute_trilu(attrs, inputs, output_type):
return [
topi_compute(
inputs[0],
inputs[1],
attrs.upper,
)
]

return _compute_trilu


# roi_pool
@generic_func
def schedule_roi_pool(attrs, outs, target):
43 changes: 43 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1889,3 +1889,46 @@ def stft(
window = _make.ones([n_fft], "int32")

return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)


def trilu(data, k, upper=True):
"""
Given a 2-D matrix or batches of 2-D matrices, returns the
upper or lower triangular part of the tensor.
Parameters
----------
data: relay.Expr
The tensor that trilu will be applied to. Must be either
a 2D matrix or a tensor of batches of 2D matrices.
k: int or relay.Expr
The number of diagonals above or below the main diagonal
to exclude or include.
upper: bool, optional
If True, only the upper triangular values of the input are kept;
if False, the lower triangular values are kept.
Returns
-------
ret : relay.Expr
The new tensor with the elements on one side of the diagonal set to zero.
Examples
--------
.. code-block:: python
x = [[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]
relay.trilu(x, 0, True) =
[[0, 1, 2],
[0, 4, 5],
[0, 0, 8]]
"""
if not isinstance(k, Expr):
k = const(k, dtype="int32")
return _make.trilu(data, k, upper)
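
Since a plain integer k is wrapped in a relay constant, the op can be called with literal arguments and run directly. A small runnable sketch — not part of the diff, assuming numpy and an LLVM-enabled build:

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(3, 3), dtype="float32")
out = relay.trilu(data, k=1, upper=False)  # keep everything on and below the first superdiagonal
mod = tvm.IRModule.from_expr(out)

x = np.arange(9, dtype="float32").reshape(3, 3)
res = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(x).numpy()
np.testing.assert_allclose(res, np.tril(x, 1))
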
21 changes: 13 additions & 8 deletions python/tvm/topi/transform.py
@@ -1003,24 +1003,25 @@ def sliding_window(data, axis, window_shape, strides):
return cpp.sliding_window(data, axis, window_shape, strides)


def trilu(x, upper, k):
def trilu(data, k, upper):
"""
Given a 2-D matrix or batches of 2-D matrices, returns the
upper or lower triangular part of the tensor.
Parameters
----------
x: tvm.te.Tensor
data: tvm.te.Tensor
The tensor that trilu will be applied to. Must be either
a 2D matrix or a tensor of batches of 2D matrices.
k: tvm.te.Tensor
The number of diagonals above or below the main diagonal
to exclude or include.
upper: bool
If True, only upper triangular values of input are kept,
if False, the lower triangular values are kept.
k: int
The number of diagonals above or below the main diagonal
to exclude or include.
Returns
-------
@@ -1040,6 +1041,10 @@ def trilu(x, upper, k):
[0, 4, 5],
[0, 0, 8]]
"""
# Make sure datatype is consistent.
if k.dtype != "int32":
k = tvm.tir.Cast("int32", k)

# Check either above or below diagonal depending on upper.
check_op = tvm.tir.GE
if upper:
@@ -1050,7 +1055,7 @@ def _apply_trilu(*indices):
col_index = indices[-1]
other_indices = indices[:-2]
check_position = check_op(row_index, col_index - k)
value = x(*other_indices, row_index, col_index)
return tvm.tir.Select(check_position, value, tvm.tir.const(0, x.dtype))
value = data(*other_indices, row_index, col_index)
return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))

return te.compute(x.shape, _apply_trilu, name="trilu")
return te.compute(data.shape, _apply_trilu, name="trilu")
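
At the TE level trilu is a single te.compute whose body selects between the input value and zero based on the row/column indices, so it can also be built standalone. A rough sketch — not part of the diff, assuming an LLVM-enabled build, with k passed as an int32 PrimExpr so the cast above is a no-op:

import numpy as np
import tvm
from tvm import te, topi

x = te.placeholder((4, 4), name="x", dtype="float32")
k = tvm.tir.const(0, "int32")                # main diagonal
y = topi.transform.trilu(x, k, upper=True)   # keep the upper triangle

s = te.create_schedule(y.op)
f = tvm.build(s, [x, y], target="llvm")

dev = tvm.cpu()
x_nd = tvm.nd.array(np.random.rand(4, 4).astype("float32"), dev)
y_nd = tvm.nd.empty((4, 4), "float32", dev)
f(x_nd, y_nd)
np.testing.assert_allclose(y_nd.numpy(), np.triu(x_nd.numpy()))
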
50 changes: 50 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -4230,5 +4230,55 @@ RELAY_REGISTER_OP("invert_permutation")
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TOpIsStateful>("TOpIsStateful", false);

// Trilu

TVM_REGISTER_NODE_TYPE(TriluAttrs);

bool TriluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, k, result]
ICHECK_EQ(types.size(), 3) << "Trilu: expect 3 types but " << types.size() << " provided";
ICHECK_EQ(num_inputs, 2) << "Trilu: expect 2 inputs but " << num_inputs << " provided";
auto data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
ICHECK(types[0].as<IncompleteTypeNode>())
<< "Trilu: expect input type to be TensorType but get " << types[0];
return false;
}

auto k = types[1].as<TensorTypeNode>();
if (k == nullptr) {
ICHECK(types[1].as<IncompleteTypeNode>())
<< "Trilu: expect k type to be TensorType but get " << types[1];
return false;
}

ICHECK(k->shape.size() == 0) << "Trilu: k must be a 0-D tensor but get " << types[1];

// Output shape is the same as input shape.
reporter->Assign(types[2], TensorType(data->shape, data->dtype));
return true;
}

Expr MakeTrilu(Expr data, Expr k, bool upper) {
auto attrs = make_object<TriluAttrs>();
attrs->upper = upper;
static const Op& op = Op::Get("trilu");
return Call(op, {data, k}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.trilu").set_body_typed(MakeTrilu);

RELAY_REGISTER_OP("trilu")
.describe(
R"code(Filters out the upper or lower portion of an input tensor on one side of a diagonal.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor")
.add_argument("k", "Tensor", "The number of diagonals above or below the main to exclude.")
.add_type_rel("trilu", TriluRel)
.set_support_level(3)
.set_attr<TOpPattern>("TOpPattern", kElemWise);

} // namespace relay
} // namespace tvm
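
TriluRel only checks that k is a 0-D tensor and then copies the input's shape and dtype to the output, so type inference treats trilu as shape-preserving. A quick sketch of observing that from Python — not part of the diff:

import tvm
from tvm import relay

data = relay.var("data", shape=(2, 5, 5), dtype="float16")
mod = tvm.IRModule.from_expr(relay.trilu(data, 0, upper=False))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)   # expected: Tensor[(2, 5, 5), float16]
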
16 changes: 0 additions & 16 deletions tests/python/frontend/onnx/test_forward.py
@@ -5241,23 +5241,7 @@ def verify_eyelike(indata, dynamic=False):
"test_training_dropout_mask",
"test_training_dropout_zero_ratio",
"test_training_dropout_zero_ratio_mask",
"test_tril",
"test_tril_pos",
"test_tril_square",
"test_tril_square_neg",
"test_tril_neg",
"test_tril_one_row_neg",
"test_tril_out_neg",
"test_tril_out_pos",
"test_tril_zero",
"test_triu",
"test_triu_one_row",
"test_triu_out_neg_out",
"test_triu_out_pos",
"test_triu_neg",
"test_triu_pos",
"test_triu_square",
"test_triu_square_neg",
"test_triu_zero",
"test_unique_sorted_with_axis",
"test_unique_sorted_with_axis_3d",
29 changes: 29 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -2207,5 +2207,34 @@ def test_stft(
)


def test_trilu(target="llvm", dev=tvm.cpu()):
def verify_trilu(data_shape, upper=True, k=0):
data = relay.var("data", relay.TensorType(data_shape, "float32"))
y = relay.trilu(data, k, upper)
mod = tvm.ir.IRModule.from_expr(y)

data_np = np.random.normal(size=data_shape).astype("float32")
tvm_res = (
relay.create_executor("graph", mod=mod, device=dev, target=target)
.evaluate()(data_np)
.numpy()
)
if upper:
np_res = np.triu(data_np, k)
else:
np_res = np.tril(data_np, k)
tvm.testing.assert_allclose(tvm_res, np_res)

# Test upper and lower triangle
verify_trilu((3, 3), True, 0)
verify_trilu((3, 3), False, 0)
# Test larger matrices with offset.
verify_trilu((6, 6), True, 1)
verify_trilu((6, 6), False, 2)
verify_trilu((6, 6), False, -2)
# Test batch size
verify_trilu((8, 6, 6), False, -2)


if __name__ == "__main__":
tvm.testing.main()
6 changes: 3 additions & 3 deletions tests/python/topi/python/test_topi_transform.py
@@ -814,7 +814,7 @@ def check_device(target, dev):

def verify_trilu(input_shape, upper, k=0):
x = te.placeholder(shape=input_shape, name="x", dtype="float32")
trilu_result = topi.transform.trilu(x, upper, k)
trilu_result = topi.transform.trilu(x, k, upper)

def check_device(target, dev):
print("Running on target: %s" % target)
@@ -885,10 +885,10 @@ def test_reinterpret():
(1000,), "int16", "uint16", lambda shape: np.random.randint(-1000, 1000, size=shape)
)
verify_reinterpret(
(1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape)
(1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2**32 - 1, size=shape)
)
verify_reinterpret(
(1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape)
(1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2**32 - 1, size=shape)
)

