diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 96901c522d22..e2b67341dcb5 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -158,6 +158,44 @@ def _visit(self, node: doc.AST) -> Any: res : Any The evaluation result. """ + args = [] + if ( + isinstance(node, doc.Call) + and hasattr(node.func, "attr") + and node.func.attr not in ["reads", "writes", "match_buffer", "realize"] + ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)): + if isinstance(node, doc.BinOp): + args = [node.left, node.right] + elif isinstance(node, doc.UnaryOp): + args = [node.operand] + elif isinstance(node, doc.Compare): + args = [node.left, *node.comparators] + else: + if isinstance(node, doc.Call): + args = node.args + elif isinstance(node, doc.BoolOp): + args = node.values + for arg in args: + if isinstance(arg, doc.Subscript) and isinstance(arg.slice, (doc.Slice, doc.Tuple)): + if isinstance(arg.slice, doc.Slice): + check_slices = [arg.slice] + else: + check_slices = [] + for p in arg.slice.elts: + if isinstance(p, doc.Slice): + check_slices.append(p) + for s in check_slices: + if not s.step and s.upper and s.lower: + s.step = doc.Constant( + 1, + None, + 1, + 1, + s.upper.lineno, + s.upper.end_col_offset + 1, + s.upper.lineno, + s.upper.end_col_offset + 2, + ) if isinstance(node, list): return [self._visit(n) for n in node] if isinstance(node, tuple): diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py index 3e120339a6e4..ab01d91c02e3 100644 --- a/python/tvm/script/parser/tir/operation.py +++ b/python/tvm/script/parser/tir/operation.py @@ -19,7 +19,9 @@ from typing import Type from tvm import tir +from tvm._ffi.runtime_ctypes import DataType, DataTypeCode from tvm.tir import IntImm +from tvm.tir.expr import FloatImm from .._core import OpMethod, doc, register_op @@ -32,14 +34,88 @@ def _and(a, b): a = IntImm("bool", a) if isinstance(b, bool): b = IntImm("bool", b) - return tir.And(a, b) + if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1: + return a & b + else: + return tir.And(a, b) def _or(a, b): if isinstance(a, bool): a = IntImm("bool", a) if isinstance(b, bool): b = IntImm("bool", b) - return tir.Or(a, b) + if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1: + return a | b + else: + return tir.Or(a, b) + + def _get_type_str(dtype: str): + if DataType(dtype).lanes == 1: + return dtype + index = dtype.find("x") + return dtype[0:index] + + def _auto_broadcast(a, b, op): + + if isinstance(a, int): + if hasattr(b, "dtype"): + if ( + DataType(b.dtype).type_code == DataTypeCode.INT + or DataType(b.dtype).type_code == DataTypeCode.UINT + ): + a = IntImm(_get_type_str(b.dtype), a) + elif DataType(b.dtype).type_code == DataTypeCode.FLOAT: + a = FloatImm(_get_type_str(b.dtype), a) + elif isinstance(b, float): + a = FloatImm("float32", a) + else: + a = IntImm("int32", a) + elif isinstance(a, float): + if DataType(b.dtype).type_code == DataTypeCode.FLOAT: + a = FloatImm(_get_type_str(b.dtype), a) + else: + a = FloatImm("float32", a) + + assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr." + if isinstance(b, int): + if ( + DataType(a.dtype).type_code == DataTypeCode.INT + or DataType(a.dtype).type_code == DataTypeCode.UINT + ): + b = IntImm(_get_type_str(a.dtype), b) + elif DataType(a.dtype).type_code == DataTypeCode.FLOAT: + b = FloatImm(_get_type_str(a.dtype), b) + elif isinstance(b, float): + b = FloatImm(_get_type_str(a.dtype), b) + + if DataType(a.dtype).lanes == DataType(b.dtype).lanes: + return op(a, b) + elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: + broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes) + return op(broadcast_a, b) + elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: + broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes) + return op(a, broadcast_b) + else: + raise TypeError("do not know how to deal with it.") + + def _eq(a, b): + return _auto_broadcast(a, b, tir.EQ) + + def _ne(a, b): + return _auto_broadcast(a, b, tir.NE) + + def _lt(a, b): + return _auto_broadcast(a, b, tir.LT) + + def _le(a, b): + return _auto_broadcast(a, b, tir.LE) + + def _gt(a, b): + return _auto_broadcast(a, b, tir.GT) + + def _ge(a, b): + return _auto_broadcast(a, b, tir.GE) def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name register_op(ty, op, i)(m) @@ -60,12 +136,12 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name # doc.MatMult <-- not implemented # doc.Pow <-- not implemented # Case 2. cmpop - r(doc.Eq, i, tir.EQ) - r(doc.NotEq, i, tir.NE) - r(doc.Lt, i, tir.LT) - r(doc.LtE, i, tir.LE) - r(doc.Gt, i, tir.GT) - r(doc.GtE, i, tir.GE) + r(doc.Eq, i, _eq) + r(doc.NotEq, i, _ne) + r(doc.Lt, i, _lt) + r(doc.LtE, i, _le) + r(doc.Gt, i, _gt) + r(doc.GtE, i, _ge) # doc.Is <-- not implemented # doc.IsNot <-- not implemented # doc.In <-- not implemented diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 7d81fecedbca..f81f9bd9ea78 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -211,6 +211,28 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: if len(node.targets) != 1: self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") lhs = node.targets[0] + + if isinstance(node.value, doc.Subscript): + check_slices = [] + if isinstance(node.value.slice, doc.Slice): + check_slices = [node.value.slice] + elif isinstance(node.value.slice, doc.Tuple): + for p in node.value.slice.elts: + if isinstance(p, doc.Slice): + check_slices.append(p) + for s in check_slices: + if not s.step and s.upper and s.lower: + s.step = doc.Constant( + 1, + None, + 1, + 1, + s.upper.lineno, + s.upper.end_col_offset + 1, + s.upper.lineno, + s.upper.end_col_offset + 2, + ) + rhs = self.eval_expr(node.value) if isinstance(lhs, doc.Subscript): if isinstance(lhs.slice, doc.Tuple): diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 764b8a3dd32a..ec57ad7801ca 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -203,11 +203,14 @@ def __getitem__(self, indices): return BufferRegion(self, region) else: expr_indices = [] - for index in indices: + for i, index in enumerate(indices): if isinstance(index, slice): start = 0 if index.start is None else index.start stop = self.shape[i] if index.stop is None else index.stop step = 1 if index.step is None else index.step + # We should ensure the dtype of start is the same with that of step. + if isinstance(start, tvm.tir.expr.PrimExpr) and isinstance(step, int): + step = tvm.tir.expr.IntImm(start.dtype, step) lanes = analyzer.simplify((stop - start + step - 1) // step) if lanes == 1: expr_indices.append(start) diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py index 8eba301fe719..671fe3cc199d 100644 --- a/tests/python/unittest/test_tvmscript_ops.py +++ b/tests/python/unittest/test_tvmscript_ops.py @@ -177,6 +177,75 @@ def test_ceildiv(): tvm.testing.assert_allclose(a.numpy(), ref) +@T.prim_func +def slice_op_test( + A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32"), C: T.Buffer((10,), "uint32") +): + B[0:5] = A[0:5] + B[0:5] + B[0:5] = A[0:5] - B[0:5] + B[0:5] = A[0:5] * B[0:5] + B[0:5] = A[0:5] / B[0:5] + C[0:5] = C[0:5] % T.broadcast(T.uint32(5), 5) + B[0:5] = -B[0:5] + C[0:5] = C[0:5] >> 4 + C[0:5] = C[0:5] << 4 + C[0:5] = C[0:5] << C[0:5] + C[0:5] = C[0:5] >> C[0:5] + T.evaluate(A[0:5] > B[0:5]) + T.evaluate(A[0:5] > 5) + T.evaluate(A[0:5] >= B[0:5]) + T.evaluate(A[0:5] >= 5) + T.evaluate(A[0:5] < B[0:5]) + T.evaluate(A[0:5] < 5) + T.evaluate(A[0:5] <= B[0:5]) + T.evaluate(A[0:5] <= 5) + T.evaluate(A[0:5] == B[0:5]) + T.evaluate(A[0:5] == 5) + T.evaluate(A[0:5] != B[0:5]) + T.evaluate(A[0:5] != 5) + T.evaluate((A[0:5] > 0) and (B[0:5] > 0)) + T.evaluate((A[0:5] > 0) or (B[0:5] > 0)) + T.evaluate((A[0:5] < 0) and (1 > 0)) + T.evaluate((A[0:5] > 0) or (1 > 0)) + + +@T.prim_func +def slice_op_test_ref( + A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32"), C: T.Buffer((10,), "uint32") +): + B[0:5] = A[0:5] + B[0:5] + B[0:5] = A[0:5] - B[0:5] + B[0:5] = A[0:5] * B[0:5] + B[0:5] = A[0:5] / B[0:5] + C[0:5] = C[0:5] % T.Broadcast(T.uint32(5), 5) + B[0:5] = B[0:5] * T.Broadcast(T.float32(-1), 5) + C[0:5] = T.shift_right(C[0:5], T.Broadcast(T.uint32(4), 5)) + C[0:5] = T.shift_left(C[0:5], T.Broadcast(T.uint32(4), 5)) + C[0:5] = T.shift_left(C[0:5], C[0:5]) + C[0:5] = T.shift_right(C[0:5], C[0:5]) + T.evaluate(A[0:5] > B[0:5]) + T.evaluate(A[0:5] > T.Broadcast(T.float32(5), 5)) + T.evaluate(A[0:5] >= B[0:5]) + T.evaluate(A[0:5] >= T.Broadcast(T.float32(5), 5)) + T.evaluate(A[0:5] < B[0:5]) + T.evaluate(A[0:5] < T.Broadcast(T.float32(5), 5)) + T.evaluate(A[0:5] <= B[0:5]) + T.evaluate(A[0:5] <= T.Broadcast(T.float32(5), 5)) + T.evaluate(A[0:5] == B[0:5]) + T.evaluate(A[0:5] == T.Broadcast(T.float32(5), 5)) + T.evaluate(A[0:5] != B[0:5]) + T.evaluate(A[0:5] != T.Broadcast(T.float32(5), 5)) + T.bitwise_and(A[0:5] > T.Broadcast(T.float32(0), 5), B[0:5] > T.Broadcast(T.float32(0), 5)) + T.bitwise_or(A[0:5] > T.Broadcast(T.float32(0), 5), B[0:5] > T.Broadcast(T.float32(0), 5)) + T.bitwise_and(A[0:5] < T.Broadcast(T.float32(0), 5), T.Broadcast(T.bool(1), 5)) + T.bitwise_or(A[0:5] > T.Broadcast(T.float32(0), 5), T.Broadcast(T.bool(1), 5)) + + +def test_slice_op(): + tvm.ir.assert_structural_equal(slice_op_test, slice_op_test_ref) + + if __name__ == "__main__": test_get_valid_counts_script_func() test_alloc_zero_dim_buffer_round_trip() + test_slice_op()