[TOPI] Fix for trilu and set_matrix_diag
mikepapadim committed Jun 16, 2022
1 parent 3eb372e commit dd0ba5a
Showing 11 changed files with 242 additions and 98 deletions.
4 changes: 0 additions & 4 deletions include/tvm/relay/attrs/transform.h
@@ -482,14 +482,10 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {

/*! \brief Attributes used in matrix_set_diag operator */
struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
int k1;
int k2;
bool super_diag_right_align;
bool sub_diag_right_align;

TVM_DECLARE_ATTRS(MatrixSetDiagAttrs, "relay.attrs.MatrixSetDiagAttrs") {
TVM_ATTR_FIELD(k1).set_default(0).describe("Lower limit (included) of the range of diagonals.");
TVM_ATTR_FIELD(k2).set_default(0).describe("Upper limit (included) of the range of diagonals.");
TVM_ATTR_FIELD(super_diag_right_align)
.set_default(true)
.describe("Bool, true iff super-diagonal is right aligned (left-padded).");
28 changes: 17 additions & 11 deletions include/tvm/topi/transform.h
@@ -244,7 +244,7 @@ inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name =
* \param x The input tensor
* \param seq_lengths A 1D Tensor with length x.dims[batch_axis]. Optional Tensor() can be passed.
* If not defined batch axis is ignored and tensor is reversed along seq_axis.
* \param seq_axis The axis along which the elements will be reveresed
* \param seq_axis The axis along which the elements will be reversed
* \param batch_axis The axis along which the tensor will be sliced
* \param name The name of the operation
* \param tag The tag to mark the operation
@@ -267,7 +267,7 @@ inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int s
ICHECK(seq_lengths_dim == 1) << "seq_lengths should be 1D vector";

ICHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
<< "For reverse_sequnece seq_lengths size should match with dimension of batch axis"
<< "For reverse_sequence seq_lengths size should match with dimension of batch axis"
<< ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
<< ", and seq_length size = " << GetConstInt(seq_lengths->shape[0]);

@@ -763,7 +763,7 @@ inline Array<PrimExpr> StridedSliceOutputShape(
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sstrided_slice operation
* \return A Tensor whose op member is the strided_slice operation
*/
inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
@@ -1744,7 +1744,7 @@ inline Tensor ndarray_size(const Tensor& src, const DataType& dtype,
}

/*!
* \brief Returns a one-hot tensor where the locations repsented by indices take value on_value,
* \brief Returns a one-hot tensor where the locations represented by indices take value on_value,
other locations take value off_value.
* \param indices locations to set to on_value.
* \param on_value value that locations represented by indices take on.
@@ -1855,14 +1855,18 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr
* \param tag output tensor tag.
* \return new tensor with given diagonal values.
*/
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
bool super_diag_right_align, bool sub_diag_right_align,
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, const Tensor& k1,
const Tensor& k2, bool super_diag_right_align,
bool sub_diag_right_align,
const std::string name = "T_matrix_set_diag",
const std::string tag = kInjective) {
size_t ndim = input->shape.size() - 1;

bool only_one_diagonal = k1 == k2;

std::cout << "\n input " << input->GetShape() << "\n"
<< "diagonal " << diagonal << "\n k1 " << k1 << " \n k2 " << k2
<< "\n bool : " << only_one_diagonal;

return compute(
input->shape,
[&](const Array<Var>& iter_vars) {
@@ -1873,11 +1877,12 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
diagonal_indices.push_back(iter_vars[i]);
}
if (only_one_diagonal) {
k = k1;
k = k1(0);
} else {
// Determining which diagonal/sub-diagonal/super-diagonal it is
k = iter_vars[ndim] - iter_vars[ndim - 1];
diagonal_indices.push_back(k2 - k);
auto idx = k2(0) - k;
diagonal_indices.push_back(idx);

// Calculating the offset in diagonal tensor for this diagonal
auto get_offset = [&](PrimExpr M, PrimExpr N) {
@@ -1895,8 +1900,9 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
offset);
return diagonal(diagonal_indices);
};
return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,

return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1(0),
if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2(0),
get_diag(), input(iter_vars)),
input(iter_vars));
},
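The updated TOPI kernel reads the diagonal bounds from rank-1 tensors (k1(0), k2(0)) instead of integer attributes; the selection rule itself is unchanged: the element at (..., i, j) lies on diagonal k = j - i and is taken from the diagonal tensor whenever k1 <= k <= k2, otherwise the input value is kept. A rough NumPy sketch of that rule for the 2-D, left-aligned case (alignment padding and batch dimensions are ignored here; this is illustrative, not the TOPI code):

import numpy as np

def matrix_set_diag_ref(data, diagonal, k1, k2):
    # diagonal has shape (k2 - k1 + 1, max_diag_len) when k1 < k2,
    # or (diag_len,) when k1 == k2
    out = data.copy()
    rows, cols = data.shape
    for i in range(rows):
        for j in range(cols):
            k = j - i                        # which diagonal (i, j) belongs to
            if k1 <= k <= k2:
                offset = i if k >= 0 else j  # position along that diagonal
                out[i, j] = diagonal[offset] if k1 == k2 else diagonal[k2 - k, offset]
    return out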
76 changes: 58 additions & 18 deletions python/tvm/relay/frontend/onnx.py
@@ -3487,7 +3487,8 @@ def get_var(name, val, scan=False):

loop_vars = [
_expr.var(body.input[0].name, shape=(), dtype=iter_dtype), # iteration count
_expr.var("max_count", shape=(), dtype=iter_dtype), # iteration count
# iteration count
_expr.var("max_count", shape=(), dtype=iter_dtype),
get_var(body.input[1].name, cond), # exit condition
]
loop_vars += [get_var(body.input[i + 2].name, v) for i, v in enumerate(loop_deps)]
@@ -4230,9 +4231,9 @@ def _impl_v10(cls, inputs, attr, params):

dtype = infer_type(a).checked_type.dtype

## Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
## and then requantize afer
## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qladd.cpp
# Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
# and then requantize afer
# https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qladd.cpp
a = _qnn.op.dequantize(
inputs[0], a_scale, a_zero_point
) # , c_scale, c_zero_point, out_dtype = dtype)
@@ -4296,7 +4297,8 @@ def try_resolve_to_const(x, dtype_override=None):
b_zp_type = infer_type(b_zp).checked_type

y_scale_type = infer_type(y_scale).checked_type
y_zp_type = infer_type(y_zp).checked_type # 'T3' in ONNX doc for this op
# 'T3' in ONNX doc for this op
y_zp_type = infer_type(y_zp).checked_type

a_shape = infer_shape(a)
b_shape = infer_shape(b)
@@ -4471,9 +4473,9 @@ def _impl_v10(cls, inputs, attr, params):

dtype = infer_type(a).checked_type.dtype

## Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
## and then requantize afer
## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qlmul.cpp
# Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
# and then requantize afer
# https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qlmul.cpp
a = _qnn.op.dequantize(inputs[0], a_scale, a_zero_point)
b = _qnn.op.dequantize(inputs[3], b_scale, b_zero_point)
out = _op.multiply(a, b)
@@ -4515,10 +4517,10 @@ def _impl_v10(cls, inputs, attr, params):

dtype = infer_type(x).checked_type.dtype

## Apparently, onnxruntime doesn't do this op in integer, they dequantize to fp32
## and then requantize after:
## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/
## providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp#L245
# Apparently, onnxruntime doesn't do this op in integer, they dequantize to fp32
# and then requantize after:
# https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/
# providers/dml/DmlExecutionProvider/src/GraphTransformer.cpp#L245
x = _qnn.op.dequantize(x, x_scale, x_zero_point)
out = _op.sigmoid(x)
return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype)
@@ -4663,12 +4665,16 @@ def _impl_v11(cls, inputs, attr, params):
unique = _op.unique(data, is_sorted=(is_sorted == 1), return_counts=True)
num_unique = unique[3]

trim_unique_lambda = lambda input: _op.strided_slice(input, _op.const([0]), num_unique)
def trim_unique_lambda(input):
return _op.strided_slice(input, _op.const([0]), num_unique)

unique_vals = trim_unique_lambda(unique[0])
indices = _op.cast(trim_unique_lambda(unique[1]), "int64") # ONNX always returns int64
inverse_indices = _op.cast(unique[2], "int64") # ONNX always returns int64
counts = _op.cast(trim_unique_lambda(unique[4]), "int64") # ONNX always returns int64
# ONNX always returns int64
indices = _op.cast(trim_unique_lambda(unique[1]), "int64")
# ONNX always returns int64
inverse_indices = _op.cast(unique[2], "int64")
# ONNX always returns int64
counts = _op.cast(trim_unique_lambda(unique[4]), "int64")
# ONNX unique returns unique, indices, inverse_indices, (optional) counts
return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4)

@@ -5087,6 +5093,37 @@ def _impl_v1(cls, inputs, attr, params):
return _expr.TupleWrapper(_expr.Tuple(result), len(result))


class Trilu(OnnxOpConverter):
"""Operator converter for Trilu"""

@classmethod
def _impl_v14(cls, inputs, attr, params):
upper = attr.get("upper", 1)
input_shape = shape_of(inputs[0])
input_dims = infer_shape(input_shape)[0]
data_type = infer_type(inputs[0]).checked_type.dtype
k_tensor = relay.const(np.asarray(0), dtype=np.int64)
if len(inputs) == 2:
k_tensor = inputs[1]

diag_input = relay.zeros(fold_constant(input_shape), dtype=data_type)
k1, k2 = None, None
if upper == 0:
k1 = relay.add(k_tensor, relay.const(1, dtype="int64"))
k1 = relay.expand_dims(k1, axis=0)
k2 = relay.take(input_shape, relay.const(input_dims - 1, dtype="int32"))
k2 = relay.expand_dims(k2, axis=0)
else:
k1 = relay.take(input_shape, relay.const(input_dims - 2, dtype="int32"))
k1 = relay.multiply(k1, relay.const(-1, dtype="int64"))
k1 = relay.subtract(k1, relay.const(1, dtype="int64"))
k1 = relay.expand_dims(k1, axis=0)
k2 = relay.subtract(k_tensor, relay.const(1, dtype="int64"))
k2 = relay.expand_dims(k2, axis=0)

return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2))


class Round(OnnxOpConverter):
"""Operator converter for round op."""

@@ -5114,6 +5151,8 @@ def _impl_v11(cls, inputs, attr, params):
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)


def _get_convert_map(opset):
return {
# defs/experimental
@@ -5287,6 +5326,7 @@ def _get_convert_map(opset):
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
"Einsum": Einsum.get_converter(opset),
"Trilu": Trilu.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
@@ -5420,8 +5460,8 @@ def from_onnx(self, graph, opset, get_output_expr=False):
# If requested, directly return the converted expressions.
if get_output_expr:
return outputs
## Maintain the order of inputs and parameters from the ONNX graph, but only include
## those parameters that are needed to execute the relay graph
# Maintain the order of inputs and parameters from the ONNX graph, but only include
# those parameters that are needed to execute the relay graph
free_vars = analysis.free_vars(outputs)
nodes = {v: k for k, v in self._nodes.items()}
free_vars = [nodes[var] for var in free_vars]
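The new Trilu converter keeps one triangle of the last two axes and zeroes the other, matching numpy.triu / numpy.tril. Rather than masking, it builds a zero tensor of the input shape and overwrites the discarded band of diagonals through relay.matrix_set_diag: diagonals k+1 and above for the lower-triangle case (upper=0), and the diagonals below k for the upper-triangle case. A small reference sketch of the intended semantics (illustrative only, not the converter code):

import numpy as np

def trilu_ref(x, k=0, upper=1):
    # ONNX Trilu: keep elements with j - i >= k (upper) or j - i <= k (lower),
    # zero everything else in the last two axes
    return np.triu(x, k) if upper else np.tril(x, k)

x = np.arange(16).reshape(4, 4)
print(trilu_ref(x, k=0, upper=0))  # lower triangle, main diagonal kept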
18 changes: 13 additions & 5 deletions python/tvm/relay/frontend/tflite.py
@@ -535,7 +535,8 @@ def convert_qnn_fused_activation_function(
raise ImportError("The tflite package must be installed")

# Quantize a float value to an quantized integer value
quantize = lambda x: float(int(round(x / scale)) + zero_point)
def quantize(x):
return float(int(round(x / scale)) + zero_point)

# Get min/max of the output dtype. This will be used to ensure that clip a_min/a_max are not
# beyond the dtype range.
@@ -1060,7 +1061,9 @@ def convert_relu_n1_to_1(self, op):
# Quantize a float value to an quantized integer value
scale_val = get_scalar_from_constant(input_tensor.qnn_params["scale"])
zero_point_val = get_scalar_from_constant(input_tensor.qnn_params["zero_point"])
quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val)

def quantize(x):
return float(int(round(x / scale_val)) + zero_point_val)

# Get min/max of the input dtype. This will be used to ensure that
# clip a_min/a_max are not beyond the dtype range.
@@ -3468,6 +3471,11 @@ def convert_matrix_set_diag(self, op):

input_expr = self.get_tensor_expr(input_tensors[0])
diagonal_expr = self.get_tensor_expr(input_tensors[1])
diag_shape = to_int_list(self.get_tensor_shape(input_tensors[1]))
input_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
if len(diag_shape) == len(input_shape) - 1:
diag_shape = np.insert(diag_shape, len(diag_shape) - 1, 1)
diagonal_expr = _op.reshape(diagonal_expr, diag_shape)

out = _op.matrix_set_diag(input_expr, diagonal_expr)
return out
@@ -3488,13 +3496,13 @@ def convert_matrix_diag(self, op):
scale and zero points to be equal"

shape = to_int_list(self.get_tensor_shape(diagonal))
shape = np.append(shape, shape[-1])
diag_shape = np.insert(shape, len(shape) - 1, 1).astype(np.int32)
dtype = self.get_tensor_type_str(diagonal.tensor.Type())

shape = np.append(shape, shape[-1]).astype(np.int32)
input_expr = _op.zeros(tuple(shape), dtype)
diagonal_expr = self.get_tensor_expr(diagonal)

out = _op.matrix_set_diag(input_expr, diagonal_expr)
out = _op.matrix_set_diag(input_expr, _op.reshape(diagonal_expr, diag_shape))
return out

def convert_densify(self, op):
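In the TFLite converters, matrix_set_diag and matrix_diag now reshape a single-diagonal tensor so that it carries an explicit length-1 axis before the diagonal length, apparently so a lone diagonal is passed in the same (..., num_diags, diag_len) layout used for a band of diagonals. A short sketch of the shape manipulation with hypothetical shapes (a (2, 4, 4) input and its (2, 4) main diagonal):

import numpy as np

diag_shape = np.array([2, 4], dtype=np.int32)  # diagonal of a (2, 4, 4) input
# insert a length-1 "number of diagonals" axis before the last dimension
diag_shape = np.insert(diag_shape, len(diag_shape) - 1, 1)
print(diag_shape)  # [2 1 4]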
10 changes: 9 additions & 1 deletion python/tvm/relay/op/transform.py
@@ -18,6 +18,7 @@
# pylint: disable=import-outside-toplevel
"""Transform operators."""

import numpy as np
from ...tir import expr as _expr
from ..expr import Constant, Expr, Tuple, TupleWrapper, const
from . import _make
@@ -1247,7 +1248,7 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):

def one_hot(indices, on_value, off_value, depth, axis, dtype):
"""
Returns a one-hot tensor where the locations repsented by indices take value on_value,
Returns a one-hot tensor where the locations represented by indices take value on_value,
other locations take value off_value.
Final dimension is <indices outer dimensions> x depth x <indices inner dimensions>.
@@ -1415,9 +1416,16 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
k_one = k
k_two = k

if not isinstance(k_one, Expr):
k_one = const(np.asarray([k_one], dtype=np.int64))
if not isinstance(k_two, Expr):
k_two = const(np.asarray([k_two], dtype=np.int64))

super_diag_right_align = align[:5] == "RIGHT"
sub_diag_right_align = align[-5:] == "RIGHT"

k_one = const(0)
k_two = const(0)
return _make.matrix_set_diag(
data, diagonal, k_one, k_two, super_diag_right_align, sub_diag_right_align
)
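On the relay side, matrix_set_diag now accepts the diagonal bounds either as relay expressions or as plain Python ints; ints are promoted to one-element int64 constants before being handed to the C++ op, which reads them as rank-1 tensors. A minimal sketch of that promotion (the values are illustrative, not from the commit):

import numpy as np
from tvm import relay

k_one, k_two = -1, 2  # keep diagonals -1 through 2
if not isinstance(k_one, relay.Expr):
    k_one = relay.const(np.asarray([k_one], dtype=np.int64))
if not isinstance(k_two, relay.Expr):
    k_two = relay.const(np.asarray([k_two], dtype=np.int64))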