Skip to content

Commit

Permalink
[TOPI][ONNX] Fix for trilu and set_matrix_diag ops
Browse files Browse the repository at this point in the history
  • Loading branch information
mikepapadim committed Jun 17, 2022
1 parent dd0ba5a commit 057dd7c
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 25 deletions.
2 changes: 0 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,8 +1424,6 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
super_diag_right_align = align[:5] == "RIGHT"
sub_diag_right_align = align[-5:] == "RIGHT"

k_one = const(0)
k_two = const(0)
return _make.matrix_set_diag(
data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align
)
Expand Down
18 changes: 0 additions & 18 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,33 +910,15 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
[7, 5, 7, 7],
[7, 7, 6, 7]]]
"""
print("\n")
print("\n")
print("\n")
if isinstance(k, (tuple, list)):
print("What is k 1 \n")
k_one = k[0]
if len(k) >= 2:
k_two = k[1]
else:
k_two = k[0]
else:
print("What is k 2 \n")
k_one = k
k_two = k
# k_one = te.placeholder(shape=(1,), name="k1", dtype="int64")
# k_two = te.placeholder(shape=(1,), name="k1", dtype="int64")

# if not isinstance(k_one, Expr):
# k_one = const(np.asarray([k_one], dtype=np.int64))
# if not isinstance(k_two, Expr):
# k_two = const(np.asarray([k_two], dtype=np.int64))

print(" one ", k_one)
print(" two ", k_two)

# k_one = te.placeholder(shape=(1,), name="k1", dtype="int64")
# k_two = k_one

super_diag_right_align = align[:5] == "RIGHT"
sub_diag_right_align = align[-5:] == "RIGHT"
Expand Down
3 changes: 0 additions & 3 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3905,9 +3905,6 @@ Array<te::Tensor> MatrixSetDiagCompute(const Attrs& attrs, const Array<te::Tenso
const Type& out_type) {
const auto* param = attrs.as<MatrixSetDiagAttrs>();
ICHECK(param != nullptr);
std::cout << "**** \n"
<< "d";
printf("*******************\n");
return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], inputs[2], inputs[3],
param->super_diag_right_align,
param->sub_diag_right_align)};
Expand Down
5 changes: 3 additions & 2 deletions src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,9 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) {
TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) {
bool super_diag_right_align = args[4];
bool sub_diag_right_align = args[5];
*rv = matrix_set_diag(args[0], args[1], args[2], args[3], super_diag_right_align,
sub_diag_right_align);
Tensor k1 = args[2];
Tensor k2 = args[3];
*rv = matrix_set_diag(args[0], args[1], k1, k2, super_diag_right_align, sub_diag_right_align);
});

TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) {
Expand Down

0 comments on commit 057dd7c

Please sign in to comment.