[TOPI][ONNX] Fix for trilu and set_matrix_diag ops #11761

Status: Open. Wants to merge 2 commits into base: main.
4 changes: 0 additions & 4 deletions include/tvm/relay/attrs/transform.h
@@ -482,14 +482,10 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {

/*! \brief Attributes used in matrix_set_diag operator */
struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
int k1;
int k2;
bool super_diag_right_align;
bool sub_diag_right_align;

TVM_DECLARE_ATTRS(MatrixSetDiagAttrs, "relay.attrs.MatrixSetDiagAttrs") {
TVM_ATTR_FIELD(k1).set_default(0).describe("Lower limit (included) of the range of diagonals.");
TVM_ATTR_FIELD(k2).set_default(0).describe("Upper limit (included) of the range of diagonals.");
TVM_ATTR_FIELD(super_diag_right_align)
.set_default(true)
.describe("Bool, true iff super-diagonal is right aligned (left-padded).");
28 changes: 17 additions & 11 deletions include/tvm/topi/transform.h
@@ -244,7 +244,7 @@ inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name =
* \param x The input tensor
* \param seq_lengths A 1D Tensor with length x.dims[batch_axis]. Optional Tensor() can be passed.
* If not defined batch axis is ignored and tensor is reversed along seq_axis.
* \param seq_axis The axis along which the elements will be reveresed
* \param seq_axis The axis along which the elements will be reversed
* \param batch_axis The axis along which the tensor will be sliced
* \param name The name of the operation
* \param tag The tag to mark the operation
@@ -267,7 +267,7 @@ inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int s
ICHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector";

ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
<< "For reverse_sequnece seq_lengths size should match with dimension of batch axis"
<< "For reverse_sequence seq_lengths size should match with dimension of batch axis"
<< ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
<< ", and seq_length size = " << GetConstInt(seq_lengths->shape[0]);

@@ -763,7 +763,7 @@ inline Array<PrimExpr> StridedSliceOutputShape(
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sstrided_slice operation
* \return A Tensor whose op member is the strided_slice operation
*/
inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
@@ -1744,7 +1744,7 @@ inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
}

/*!
* \brief Returns a one-hot tensor where the locations repsented by indices take value on_value,
* \brief Returns a one-hot tensor where the locations represented by indices take value on_value,
other locations take value off_value.
* \param indices locations to set to on_value.
* \param on_value value that locations represented by indices take on.
@@ -1855,14 +1855,18 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr
* \param tag output tensor tag.
* \return new tensor with given diagonal values.
*/
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
bool super_diag_right_align, bool sub_diag_right_align,
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, const Tensor& k1,
const Tensor& k2, bool super_diag_right_align,
bool sub_diag_right_align,
const std::string name = "T_matrix_set_diag",
const std::string tag = kInjective) {
size_t ndim = input->shape.size() - 1;

bool only_one_diagonal = k1 == k2;

return compute(
input->shape,
[&](const Array<Var>& iter_vars) {
@@ -1873,11 +1877,12 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
diagonal_indices.push_back(iter_vars[i]);
}
if (only_one_diagonal) {
k = k1;
k = k1(0);
} else {
// Determining which diagonal/sub-diagonal/super-diagonal it is
k = iter_vars[ndim] - iter_vars[ndim - 1];
diagonal_indices.push_back(k2 - k);
auto idx = k2(0) - k;
diagonal_indices.push_back(idx);

// Calculating the offset in diagonal tensor for this diagonal
auto get_offset = [&](PrimExpr M, PrimExpr N) {
@@ -1895,8 +1900,9 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
offset);
return diagonal(diagonal_indices);
};
return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,

return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1(0),
if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2(0),
get_diag(), input(iter_vars)),
input(iter_vars));
},
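For orientation, a rough NumPy reference for the band test used in the compute body above: element (i, j) sits on diagonal k = j - i and is taken from the diagonal tensor iff k1 <= k <= k2. This is only a sketch; it assumes left alignment and omits the offset handling of the real kernel.

import numpy as np

def matrix_set_diag_ref(data, diagonal, k1, k2):
    # data: (..., rows, cols); diagonal: (..., diag_len) for a single diagonal,
    # or (..., k2 - k1 + 1, max_diag_len) when a band of diagonals is given.
    out = data.copy()
    rows, cols = data.shape[-2], data.shape[-1]
    for i in range(rows):
        for j in range(cols):
            k = j - i                      # which diagonal (i, j) lies on
            if k1 <= k <= k2:
                if k1 == k2:               # single diagonal: index along it
                    out[..., i, j] = diagonal[..., min(i, j)]
                else:                      # band: pick the diagonal, then the position
                    out[..., i, j] = diagonal[..., k2 - k, min(i, j)]
    return out
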
76 changes: 58 additions & 18 deletions python/tvm/relay/frontend/onnx.py
@@ -3487,7 +3487,8 @@ def get_var(name, val, scan=False):

loop_vars = [
_expr.var(body.input[0].name, shape=(), dtype=iter_dtype), # iteration count
_expr.var("max_count", shape=(), dtype=iter_dtype), # iteration count
# iteration count
_expr.var("max_count", shape=(), dtype=iter_dtype),
get_var(body.input[1].name, cond), # exit condition
]
loop_vars += [get_var(body.input[i + 2].name, v) for i, v in enumerate(loop_deps)]
@@ -4230,9 +4231,9 @@ def _impl_v10(cls, inputs, attr, params):

dtype = infer_type(a).checked_type.dtype

## Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
## and then requantize afer
## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qladd.cpp
# Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
# and then requantize after
# https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qladd.cpp
a = _qnn.op.dequantize(
inputs[0], a_scale, a_zero_point
) # , c_scale, c_zero_point, out_dtype = dtype)
@@ -4296,7 +4297,8 @@ def try_resolve_to_const(x, dtype_override=None):
b_zp_type = infer_type(b_zp).checked_type

y_scale_type = infer_type(y_scale).checked_type
y_zp_type = infer_type(y_zp).checked_type # 'T3' in ONNX doc for this op
# 'T3' in ONNX doc for this op
y_zp_type = infer_type(y_zp).checked_type

a_shape = infer_shape(a)
b_shape = infer_shape(b)
@@ -4471,9 +4473,9 @@ def _impl_v10(cls, inputs, attr, params):

dtype = infer_type(a).checked_type.dtype

## Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
## and then requantize afer
## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qlmul.cpp
# Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
# and then requantize after
# https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qlmul.cpp
a = _qnn.op.dequantize(inputs[0], a_scale, a_zero_point)
b = _qnn.op.dequantize(inputs[3], b_scale, b_zero_point)
out = _op.multiply(a, b)
@@ -4515,10 +4517,10 @@ def _impl_v10(cls, inputs, attr, params):

dtype = infer_type(x).checked_type.dtype

## Apparently, onnxruntime doesn't do this op in integer, they dequantize to fp32
## and then requantize after:
## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/
## providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp#L245
# Apparently, onnxruntime doesn't do this op in integer, they dequantize to fp32
# and then requantize after:
# https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/
# providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp#L245
x = _qnn.op.dequantize(x, x_scale, x_zero_point)
out = _op.sigmoid(x)
return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype)
@@ -4663,12 +4665,16 @@ def _impl_v11(cls, inputs, attr, params):
unique = _op.unique(data, is_sorted=(is_sorted == 1), return_counts=True)
num_unique = unique[3]

trim_unique_lambda = lambda input: _op.strided_slice(input, _op.const([0]), num_unique)
def trim_unique_lambda(input):
return _op.strided_slice(input, _op.const([0]), num_unique)

unique_vals = trim_unique_lambda(unique[0])
indices = _op.cast(trim_unique_lambda(unique[1]), "int64") # ONNX always returns int64
inverse_indices = _op.cast(unique[2], "int64") # ONNX always returns int64
counts = _op.cast(trim_unique_lambda(unique[4]), "int64") # ONNX always returns int64
# ONNX always returns int64
indices = _op.cast(trim_unique_lambda(unique[1]), "int64")
# ONNX always returns int64
inverse_indices = _op.cast(unique[2], "int64")
# ONNX always returns int64
counts = _op.cast(trim_unique_lambda(unique[4]), "int64")
# ONNX unique returns unique, indices, inverse_indices, (optional) counts
return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4)

@@ -5087,6 +5093,37 @@ def _impl_v1(cls, inputs, attr, params):
return _expr.TupleWrapper(_expr.Tuple(result), len(result))


class Trilu(OnnxOpConverter):
"""Operator converter for Trilu"""

@classmethod
def _impl_v14(cls, inputs, attr, params):
upper = attr.get("upper", 1)
input_shape = shape_of(inputs[0])
input_dims = infer_shape(input_shape)[0]
data_type = infer_type(inputs[0]).checked_type.dtype
k_tensor = relay.const(np.asarray(0), dtype=np.int64)
if len(inputs) == 2:
k_tensor = inputs[1]

diag_input = relay.zeros(fold_constant(input_shape), dtype=data_type)
k1, k2 = None, None
if upper == 0:
k1 = relay.add(k_tensor, relay.const(1, dtype="int64"))
k1 = relay.expand_dims(k1, axis=0)
k2 = relay.take(input_shape, relay.const(input_dims - 1, dtype="int32"))
k2 = relay.expand_dims(k2, axis=0)
else:
k1 = relay.take(input_shape, relay.const(input_dims - 2, dtype="int32"))
k1 = relay.multiply(k1, relay.const(-1, dtype="int64"))
k1 = relay.subtract(k1, relay.const(1, dtype="int64"))
k1 = relay.expand_dims(k1, axis=0)
k2 = relay.subtract(k_tensor, relay.const(1, dtype="int64"))
k2 = relay.expand_dims(k2, axis=0)

return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2))
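
As a sanity check on this mapping: for a 2-D input, ONNX Trilu(upper=1, k) keeps elements with j - i >= k (np.triu(x, k)), while Trilu(upper=0, k) keeps elements with j - i <= k (np.tril(x, k)); the converter reaches the same result by overwriting the complementary band of diagonals with zeros. A small illustrative NumPy check, with assumed values that are not part of the change:

import numpy as np

x = np.arange(1, 13).reshape(3, 4)
k = 1
j_minus_i = np.arange(4)[None, :] - np.arange(3)[:, None]   # j - i per element

upper = np.where(j_minus_i >= k, x, 0)   # Trilu(upper=1): diagonals below k zeroed
lower = np.where(j_minus_i <= k, x, 0)   # Trilu(upper=0): diagonals above k zeroed

assert (upper == np.triu(x, k)).all()
assert (lower == np.tril(x, k)).all()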


class Round(OnnxOpConverter):
"""Operator converter for round op."""

@@ -5114,6 +5151,8 @@ def _impl_v11(cls, inputs, attr, params):
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)


def _get_convert_map(opset):
return {
# defs/experimental
@@ -5287,6 +5326,7 @@ def _get_convert_map(opset):
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
"Einsum": Einsum.get_converter(opset),
"Trilu": Trilu.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
@@ -5420,8 +5460,8 @@ def from_onnx(self, graph, opset, get_output_expr=False):
# If requested, directly return the converted expressions.
if get_output_expr:
return outputs
## Maintain the order of inputs and parameters from the ONNX graph, but only include
## those parameters that are needed to execute the relay graph
# Maintain the order of inputs and parameters from the ONNX graph, but only include
# those parameters that are needed to execute the relay graph
free_vars = analysis.free_vars(outputs)
nodes = {v: k for k, v in self._nodes.items()}
free_vars = [nodes[var] for var in free_vars]
18 changes: 13 additions & 5 deletions python/tvm/relay/frontend/tflite.py
@@ -535,7 +535,8 @@ def convert_qnn_fused_activation_function(
raise ImportError("The tflite package must be installed")

# Quantize a float value to an quantized integer value
quantize = lambda x: float(int(round(x / scale)) + zero_point)
def quantize(x):
return float(int(round(x / scale)) + zero_point)

# Get min/max of the output dtype. This will be used to ensure that clip a_min/a_max are not
# beyond the dtype range.
@@ -1060,7 +1061,9 @@ def convert_relu_n1_to_1(self, op):
# Quantize a float value to an quantized integer value
scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"])
zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"])
quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val)

def quantize(x):
return float(int(round(x / scale_val)) + zero_point_val)

# Get min/max of the input dtype. This will be used to ensure that
# clip a_min/a_max are not beyond the dtype range.
@@ -3468,6 +3471,11 @@ def convert_matrix_set_diag(self, op):

input_expr = self.get_tensor_expr(input_tensors[0])
diagonal_expr = self.get_tensor_expr(input_tensors[1])
diag_shape = to_int_list(self.get_tensor_shape(input_tensors[1]))
input_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
if len(diag_shape) == len(input_shape) - 1:
diag_shape = np.insert(diag_shape, len(diag_shape) - 1, 1)
diagonal_expr = _op.reshape(diagonal_expr, diag_shape)

out = _op.matrix_set_diag(input_expr, diagonal_expr)
return out
@@ -3488,13 +3496,13 @@ def convert_matrix_diag(self, op):
scale and zero points to be equal"

shape = to_int_list(self.get_tensor_shape(diagonal))
shape = np.append(shape, shape[-1])
diag_shape = np.insert(shape, len(shape) - 1, 1).astype(np.int32)
dtype = self.get_tensor_type_str(diagonal.tensor.Type())

shape = np.append(shape, shape[-1]).astype(np.int32)
input_expr = _op.zeros(tuple(shape), dtype)
diagonal_expr = self.get_tensor_expr(diagonal)

out = _op.matrix_set_diag(input_expr, diagonal_expr)
out = _op.matrix_set_diag(input_expr, _op.reshape(diagonal_expr, diag_shape))
return out
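
For intuition on the shape bookkeeping in convert_matrix_diag, assuming a diagonal tensor of shape (2, 4) (illustrative values only):

import numpy as np

shape = np.array([2, 4])                                           # diagonal tensor shape
diag_shape = np.insert(shape, len(shape) - 1, 1).astype(np.int32)  # (2, 1, 4)
out_shape = np.append(shape, shape[-1]).astype(np.int32)           # (2, 4, 4)
# zeros(out_shape) then receives the reshaped diagonal via matrix_set_diag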

def convert_densify(self, op):
8 changes: 7 additions & 1 deletion python/tvm/relay/op/transform.py
@@ -18,6 +18,7 @@
# pylint: disable=import-outside-toplevel
"""Transform operators."""

import numpy as np
from ...tir import expr as _expr
from ..expr import Constant, Expr, Tuple, TupleWrapper, const
from . import _make
@@ -1247,7 +1248,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):

def one_hot(indices, on_value, off_value, depth, axis, dtype):
"""
Returns a one-hot tensor where the locations repsented by indices take value on_value,
Returns a one-hot tensor where the locations represented by indices take value on_value,
other locations take value off_value.
Final dimension is <indices outer dimensions> x depth x <indices inner dimensions>.

@@ -1415,6 +1416,11 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
k_one = k
k_two = k

if not isinstance(k_one, Expr):
k_one = const(np.asarray([k_one], dtype=np.int64))
if not isinstance(k_two, Expr):
k_two = const(np.asarray([k_two], dtype=np.int64))

super_diag_right_align = align[:5] == "RIGHT"
sub_diag_right_align = align[-5:] == "RIGHT"

9 changes: 6 additions & 3 deletions python/tvm/topi/transform.py
@@ -17,9 +17,12 @@
# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name
"""Injective transformation operators"""
from __future__ import absolute_import as _abs
import numpy as np
from tables import Expr
[Review comment, Contributor]
are these two lines necessary? i am getting a tables not found on my end and it seems like those two libs aren't referenced anyway

[Review comment, shingjan (Contributor), Jun 29, 2022]
I am hitting the following error:

E             0: tvm::codegen::CodeGenLLVM::CreateBufferPtr(llvm::Value*, tvm::runtime::DataType, llvm::ArrayRef<llvm::Value*>, tvm::runtime::DataType)
E                   at /home/yj/tvm/src/target/llvm/codegen_llvm.cc:737
E             File "/home/yj/tvm/src/target/llvm/codegen_llvm.cc", line 737
E           TVMError:
E           ---------------------------------------------------------------
E           An error occurred during the execution of TVM.
E           For more information, please see: https://tvm.apache.org/docs/errors.html
E           ---------------------------------------------------------------
E             Check failed: (index->getType()->isIntegerTy()) is false: Expected buffer index to be an integer

I wonder if there are still places in the codegen that assume k1, k2 are integers.

[Review comment, AndrewZhaoLuo (Contributor), Jun 29, 2022]
I think it's probably a spurious import (e.g. typed Expr and the IDE auto-imported this when it should be tvm's).

Suggested change (remove this line):
from tables import Expr

import tvm
from tvm import te
from tvm import topi
from tvm.runtime.object_generic import const
from tvm.te import hybrid
from . import cpp
from . import tag
@@ -132,7 +135,7 @@ def flip(a, axis=0):
The tensor to be expanded.

axis : int, optional
The axis along which the tensors will be reveresed.
The axis along which the tensors will be reversed.

Returns
-------
@@ -183,7 +186,7 @@ def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"):
The indices to begin with in the slicing.

end : list of int
Indicies indicating end of the slice.
Indices indicating end of the slice.

strides : list of int, optional
Specifies the stride values, it can be negative
@@ -757,7 +760,7 @@ def where(condition, x, y):

def one_hot(indices, on_value, off_value, depth, axis, dtype):
"""
Returns a one-hot tensor where the locations repsented by indices take value on_value,
Returns a one-hot tensor where the locations represented by indices take value on_value,
other locations take value off_value.
Final dimension is <indices outer dimensions> x depth x <indices inner dimensions>.
