Skip to content

Commit

Permalink
[TIR] Enhance TVMScript Buffer Slice Access (#14693)
Browse files Browse the repository at this point in the history
  • Loading branch information
lightzhan-intellif authored Jun 4, 2023
1 parent 9877db5 commit 429f601
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 9 deletions.
38 changes: 38 additions & 0 deletions python/tvm/script/parser/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,44 @@ def _visit(self, node: doc.AST) -> Any:
res : Any
The evaluation result.
"""
args = []
if (
isinstance(node, doc.Call)
and hasattr(node.func, "attr")
and node.func.attr not in ["reads", "writes", "match_buffer", "realize"]
) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)):
if isinstance(node, doc.BinOp):
args = [node.left, node.right]
elif isinstance(node, doc.UnaryOp):
args = [node.operand]
elif isinstance(node, doc.Compare):
args = [node.left, *node.comparators]
else:
if isinstance(node, doc.Call):
args = node.args
elif isinstance(node, doc.BoolOp):
args = node.values
for arg in args:
if isinstance(arg, doc.Subscript) and isinstance(arg.slice, (doc.Slice, doc.Tuple)):
if isinstance(arg.slice, doc.Slice):
check_slices = [arg.slice]
else:
check_slices = []
for p in arg.slice.elts:
if isinstance(p, doc.Slice):
check_slices.append(p)
for s in check_slices:
if not s.step and s.upper and s.lower:
s.step = doc.Constant(
1,
None,
1,
1,
s.upper.lineno,
s.upper.end_col_offset + 1,
s.upper.lineno,
s.upper.end_col_offset + 2,
)
if isinstance(node, list):
return [self._visit(n) for n in node]
if isinstance(node, tuple):
Expand Down
92 changes: 84 additions & 8 deletions python/tvm/script/parser/tir/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from typing import Type

from tvm import tir
from tvm._ffi.runtime_ctypes import DataType, DataTypeCode
from tvm.tir import IntImm
from tvm.tir.expr import FloatImm

from .._core import OpMethod, doc, register_op

Expand All @@ -32,14 +34,88 @@ def _and(a, b):
a = IntImm("bool", a)
if isinstance(b, bool):
b = IntImm("bool", b)
return tir.And(a, b)
if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1:
return a & b
else:
return tir.And(a, b)

def _or(a, b):
    """Build the logical OR of two operands for the TVMScript ``or`` operator.

    Python ``bool`` operands are first wrapped as ``IntImm("bool", ...)``.
    If either operand's dtype has more than one lane the overloaded ``|``
    operator is applied; otherwise a ``tir.Or`` node is constructed directly.

    Parameters
    ----------
    a : Union[PrimExpr, bool]
        Left operand.
    b : Union[PrimExpr, bool]
        Right operand.

    Returns
    -------
    res : PrimExpr
        ``a | b`` for multi-lane operands, ``tir.Or(a, b)`` otherwise.
    """
    if isinstance(a, bool):
        a = IntImm("bool", a)
    if isinstance(b, bool):
        b = IntImm("bool", b)
    # The lane check must run before any tir.Or is built: an early
    # unconditional return here would make the multi-lane path unreachable.
    if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1:
        return a | b
    return tir.Or(a, b)

def _get_type_str(dtype: str):
    """Return ``dtype`` with its lane suffix removed.

    A dtype whose lane count is 1 is returned unchanged.  Otherwise the text
    before the first ``"x"`` is returned (for instance ``"float32x4"`` would
    yield ``"float32"``).
    """
    if DataType(dtype).lanes != 1:
        lane_sep = dtype.find("x")
        return dtype[:lane_sep]
    return dtype

def _auto_broadcast(a, b, op):
    """Apply the binary ``op`` to ``a`` and ``b`` after aligning literals and lanes.

    Python ``int``/``float`` operands are wrapped as ``IntImm``/``FloatImm``.
    Where the other operand exposes a ``dtype``, the literal takes that
    operand's element type (lane suffix stripped via ``_get_type_str``).
    Afterwards, if exactly one operand has a single lane, it is wrapped in
    ``tir.Broadcast`` to the other operand's lane count before ``op`` is called.

    Parameters
    ----------
    a : Union[PrimExpr, int, float]
        Left operand.
    b : Union[PrimExpr, int, float]
        Right operand.
    op : Callable
        Binary constructor called as ``op(a, b)``, e.g. ``tir.EQ``.

    Returns
    -------
    res : Any
        Whatever ``op`` returns for the aligned operands.

    Raises
    ------
    AssertionError
        If ``a`` is still not a ``tir.PrimExpr`` after literal conversion.
    TypeError
        If the lane counts differ and neither operand has a single lane.
    """
    if isinstance(a, int):
        if hasattr(b, "dtype"):
            b_code = DataType(b.dtype).type_code
            if b_code in (DataTypeCode.INT, DataTypeCode.UINT):
                a = IntImm(_get_type_str(b.dtype), a)
            elif b_code == DataTypeCode.FLOAT:
                a = FloatImm(_get_type_str(b.dtype), a)
        elif isinstance(b, float):
            a = FloatImm("float32", a)
        else:
            a = IntImm("int32", a)
    elif isinstance(a, float):
        # Guard the dtype access the same way the int branch does, so a float
        # literal paired with an operand lacking ``dtype`` falls back to
        # float32 instead of raising AttributeError.
        if hasattr(b, "dtype") and DataType(b.dtype).type_code == DataTypeCode.FLOAT:
            a = FloatImm(_get_type_str(b.dtype), a)
        else:
            a = FloatImm("float32", a)

    assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr."
    if isinstance(b, int):
        a_code = DataType(a.dtype).type_code
        if a_code in (DataTypeCode.INT, DataTypeCode.UINT):
            b = IntImm(_get_type_str(a.dtype), b)
        elif a_code == DataTypeCode.FLOAT:
            b = FloatImm(_get_type_str(a.dtype), b)
    elif isinstance(b, float):
        b = FloatImm(_get_type_str(a.dtype), b)

    lanes_a = DataType(a.dtype).lanes
    lanes_b = DataType(b.dtype).lanes
    if lanes_a == lanes_b:
        return op(a, b)
    # From here on the lane counts differ, so only the single-lane side
    # can be broadcast to match the other.
    if lanes_a == 1:
        return op(tir.Broadcast(a, lanes_b), b)
    if lanes_b == 1:
        return op(a, tir.Broadcast(b, lanes_a))
    raise TypeError("do not know how to deal with it.")

def _eq(a, b):
    """Build ``a == b`` through ``_auto_broadcast`` using ``tir.EQ``."""
    return _auto_broadcast(a, b, op=tir.EQ)

def _ne(a, b):
    """Build ``a != b`` through ``_auto_broadcast`` using ``tir.NE``."""
    return _auto_broadcast(a, b, op=tir.NE)

def _lt(a, b):
    """Build ``a < b`` through ``_auto_broadcast`` using ``tir.LT``."""
    return _auto_broadcast(a, b, op=tir.LT)

def _le(a, b):
    """Build ``a <= b`` through ``_auto_broadcast`` using ``tir.LE``."""
    return _auto_broadcast(a, b, op=tir.LE)

def _gt(a, b):
    """Build ``a > b`` through ``_auto_broadcast`` using ``tir.GT``."""
    return _auto_broadcast(a, b, op=tir.GT)

def _ge(a, b):
    """Build ``a >= b`` through ``_auto_broadcast`` using ``tir.GE``."""
    return _auto_broadcast(a, b, op=tir.GE)

def r(op: Type, i: int, m: OpMethod):  # pylint: disable=invalid-name
    """Register method ``m`` for operator ``op`` via ``register_op(ty, op, i)``.

    ``ty`` is not a parameter; it is taken from the surrounding scope.
    """
    register = register_op(ty, op, i)
    register(m)
Expand All @@ -60,12 +136,12 @@ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name
# doc.MatMult <-- not implemented
# doc.Pow <-- not implemented
# Case 2. cmpop
r(doc.Eq, i, tir.EQ)
r(doc.NotEq, i, tir.NE)
r(doc.Lt, i, tir.LT)
r(doc.LtE, i, tir.LE)
r(doc.Gt, i, tir.GT)
r(doc.GtE, i, tir.GE)
r(doc.Eq, i, _eq)
r(doc.NotEq, i, _ne)
r(doc.Lt, i, _lt)
r(doc.LtE, i, _le)
r(doc.Gt, i, _gt)
r(doc.GtE, i, _ge)
# doc.Is <-- not implemented
# doc.IsNot <-- not implemented
# doc.In <-- not implemented
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,28 @@ def visit_assign(self: Parser, node: doc.Assign) -> None:
if len(node.targets) != 1:
self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
lhs = node.targets[0]

if isinstance(node.value, doc.Subscript):
check_slices = []
if isinstance(node.value.slice, doc.Slice):
check_slices = [node.value.slice]
elif isinstance(node.value.slice, doc.Tuple):
for p in node.value.slice.elts:
if isinstance(p, doc.Slice):
check_slices.append(p)
for s in check_slices:
if not s.step and s.upper and s.lower:
s.step = doc.Constant(
1,
None,
1,
1,
s.upper.lineno,
s.upper.end_col_offset + 1,
s.upper.lineno,
s.upper.end_col_offset + 2,
)

rhs = self.eval_expr(node.value)
if isinstance(lhs, doc.Subscript):
if isinstance(lhs.slice, doc.Tuple):
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,14 @@ def __getitem__(self, indices):
return BufferRegion(self, region)
else:
expr_indices = []
for index in indices:
for i, index in enumerate(indices):
if isinstance(index, slice):
start = 0 if index.start is None else index.start
stop = self.shape[i] if index.stop is None else index.stop
step = 1 if index.step is None else index.step
# Ensure that the dtype of step is the same as that of start.
if isinstance(start, tvm.tir.expr.PrimExpr) and isinstance(step, int):
step = tvm.tir.expr.IntImm(start.dtype, step)
lanes = analyzer.simplify((stop - start + step - 1) // step)
if lanes == 1:
expr_indices.append(start)
Expand Down
69 changes: 69 additions & 0 deletions tests/python/unittest/test_tvmscript_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,75 @@ def test_ceildiv():
tvm.testing.assert_allclose(a.numpy(), ref)


# TVMScript input for test_slice_op: applies Python operators to 5-element
# buffer slices, mixing slice operands with each other and with Python
# literals.  It is compared structurally against slice_op_test_ref.
@T.prim_func
def slice_op_test(
A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32"), C: T.Buffer((10,), "uint32")
):
# Arithmetic between two float32 slices.
B[0:5] = A[0:5] + B[0:5]
B[0:5] = A[0:5] - B[0:5]
B[0:5] = A[0:5] * B[0:5]
B[0:5] = A[0:5] / B[0:5]
# Modulo against an explicitly broadcast uint32 constant.
C[0:5] = C[0:5] % T.broadcast(T.uint32(5), 5)
# Unary negation of a slice.
B[0:5] = -B[0:5]
# Shifts by a Python int literal and by another slice.
C[0:5] = C[0:5] >> 4
C[0:5] = C[0:5] << 4
C[0:5] = C[0:5] << C[0:5]
C[0:5] = C[0:5] >> C[0:5]
# Each comparison operator, once slice-vs-slice and once slice-vs-int-literal.
T.evaluate(A[0:5] > B[0:5])
T.evaluate(A[0:5] > 5)
T.evaluate(A[0:5] >= B[0:5])
T.evaluate(A[0:5] >= 5)
T.evaluate(A[0:5] < B[0:5])
T.evaluate(A[0:5] < 5)
T.evaluate(A[0:5] <= B[0:5])
T.evaluate(A[0:5] <= 5)
T.evaluate(A[0:5] == B[0:5])
T.evaluate(A[0:5] == 5)
T.evaluate(A[0:5] != B[0:5])
T.evaluate(A[0:5] != 5)
# Boolean and/or: two slice comparisons, then a slice comparison combined
# with a comparison of two Python literals.
T.evaluate((A[0:5] > 0) and (B[0:5] > 0))
T.evaluate((A[0:5] > 0) or (B[0:5] > 0))
T.evaluate((A[0:5] < 0) and (1 > 0))
T.evaluate((A[0:5] > 0) or (1 > 0))


# Reference function for test_slice_op: the same statement sequence as
# slice_op_test, but with literals written as explicit T.Broadcast(...)
# operands and shifts / boolean combinations spelled as T.shift_left,
# T.shift_right, T.bitwise_and and T.bitwise_or calls.
@T.prim_func
def slice_op_test_ref(
A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32"), C: T.Buffer((10,), "uint32")
):
# Arithmetic between two float32 slices.
B[0:5] = A[0:5] + B[0:5]
B[0:5] = A[0:5] - B[0:5]
B[0:5] = A[0:5] * B[0:5]
B[0:5] = A[0:5] / B[0:5]
C[0:5] = C[0:5] % T.Broadcast(T.uint32(5), 5)
# Expected form of the unary negation: multiply by a broadcast float32(-1).
B[0:5] = B[0:5] * T.Broadcast(T.float32(-1), 5)
# Expected form of shifts: int literals become broadcast uint32 constants.
C[0:5] = T.shift_right(C[0:5], T.Broadcast(T.uint32(4), 5))
C[0:5] = T.shift_left(C[0:5], T.Broadcast(T.uint32(4), 5))
C[0:5] = T.shift_left(C[0:5], C[0:5])
C[0:5] = T.shift_right(C[0:5], C[0:5])
# Expected form of comparisons: the int literal 5 becomes a broadcast float32(5).
T.evaluate(A[0:5] > B[0:5])
T.evaluate(A[0:5] > T.Broadcast(T.float32(5), 5))
T.evaluate(A[0:5] >= B[0:5])
T.evaluate(A[0:5] >= T.Broadcast(T.float32(5), 5))
T.evaluate(A[0:5] < B[0:5])
T.evaluate(A[0:5] < T.Broadcast(T.float32(5), 5))
T.evaluate(A[0:5] <= B[0:5])
T.evaluate(A[0:5] <= T.Broadcast(T.float32(5), 5))
T.evaluate(A[0:5] == B[0:5])
T.evaluate(A[0:5] == T.Broadcast(T.float32(5), 5))
T.evaluate(A[0:5] != B[0:5])
T.evaluate(A[0:5] != T.Broadcast(T.float32(5), 5))
# Expected form of and/or on slices: bitwise ops, with the literal-only
# comparison written as a broadcast bool constant.
# NOTE(review): these are bare expression statements, unlike the
# T.evaluate(...) wrappers in slice_op_test; presumably the parser wraps
# them so the two functions compare equal -- confirm.
T.bitwise_and(A[0:5] > T.Broadcast(T.float32(0), 5), B[0:5] > T.Broadcast(T.float32(0), 5))
T.bitwise_or(A[0:5] > T.Broadcast(T.float32(0), 5), B[0:5] > T.Broadcast(T.float32(0), 5))
T.bitwise_and(A[0:5] < T.Broadcast(T.float32(0), 5), T.Broadcast(T.bool(1), 5))
T.bitwise_or(A[0:5] > T.Broadcast(T.float32(0), 5), T.Broadcast(T.bool(1), 5))


def test_slice_op():
    """Check that ``slice_op_test`` is structurally equal to ``slice_op_test_ref``."""
    actual, expected = slice_op_test, slice_op_test_ref
    tvm.ir.assert_structural_equal(actual, expected)


# When the file is executed directly, call the listed test functions in order.
if __name__ == "__main__":
test_get_valid_counts_script_func()
test_alloc_zero_dim_buffer_round_trip()
test_slice_op()

0 comments on commit 429f601

Please sign in to comment.