Skip to content

Commit

Permalink
[TVMScript] Represent ramp as index slice (apache#11308)
Browse files Browse the repository at this point in the history
* support represent ramp as index slice in tvmscript

* fix testcase's comment, check slice lanes instead of extent
  • Loading branch information
wrongtest-intellif authored and Yuanjing Shi committed May 17, 2022
1 parent 1d348e4 commit 4893509
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 98 deletions.
20 changes: 17 additions & 3 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,14 @@ def transform_SubscriptAssign(self, node):
f"cannot be indexed by {len(indexes)}-dimensional indices.",
node.params[1].span,
)

def __convert_index(x):
if isinstance(x, Slice):
return x.as_index_expr(self.report_error)
return x

# BufferStore
indexes = [__convert_index(x) for x in indexes]
return tvm.tir.BufferStore(
symbol,
tvm.runtime.convert(rhs, span=rhs_span),
Expand Down Expand Up @@ -948,11 +955,18 @@ def f():
)

def transform_Slice(self, node):
"""Index slice visitor."""
start = self.transform(node.start)
end = self.transform(node.end)
if not (isinstance(node.step, ast.Constant) and node.step.value == 1):
self.report_error("Only step size 1 is supported for slices.", node.step.span)
return Slice(start, end)
if not (
isinstance(node.step, ast.Constant)
and isinstance(node.step.value, int)
and node.step.value > 0
):
self.report_error(
"Only positive integer step size is supported for slices.", node.step.span
)
return Slice(start, end, node.step.value, tvm_span_from_synr(node.span))

def transform_Subscript(self, node):
"""Array access visitor.
Expand Down
63 changes: 55 additions & 8 deletions python/tvm/script/tir/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

from typing import Optional, Union, List, Callable
import synr

from tvm.arith import Analyzer
from tvm.runtime import ObjectGeneric, convert
from tvm.tir import PrimExpr, Buffer, BufferLoad
from tvm.ir import Span
from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
from tvm.ir import Span, Range


class Slice:
Expand All @@ -36,24 +36,49 @@ class Slice:
stop : Optional[Union[PrimExpr, int]]
The stop index, None means the Slice is an element-wise index
step : int
The slice step
span : Optional[Span]
The location of the slice in the source.
"""

start: Union[PrimExpr, int]
stop: Optional[Union[PrimExpr, int]]
step: int
span: Optional[Span]

def __init__(
self,
start: Union[PrimExpr, int],
stop: Optional[Union[PrimExpr, int]] = None,
step: int = 1,
span: Optional[Span] = None,
):
self.start = start
self.stop = stop
self.step = step
self.span = span

def as_index_expr(self, report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
"""Helper to create index PrimExpr from slice object
Parameters
----------
report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
The error report func
"""
if self.stop is None:
# scalar index
return self.start
if self.step < 1:
report_error("Slice's step should be positive integer", self.span)
lanes = Analyzer().simplify((self.stop - self.start + self.step - 1) // self.step)
if not isinstance(lanes, (int, IntImm)):
report_error("Slice's lanes should be constant for buffer indices", self.span)
if lanes == 1:
return self.start
return Ramp(self.start, self.step, int(lanes), self.span)


class BufferSlice(ObjectGeneric):
"""A generic object for representing general buffer access. Following cases are supported:
Expand Down Expand Up @@ -148,13 +173,35 @@ def __str__(self):

def asobject(self) -> BufferLoad:
"""Convert object."""
for s in self.slices:
if s.stop is not None:
self.report_error("BufferLoad only accepts elementwise access", self.span)

indices = [s.start for s in self.slices]
indices = [s.as_index_expr(self.report_error) for s in self.slices]
return BufferLoad(self.buffer, indices, span=self.span)

def as_buffer_region(self, analyzer: Optional[Analyzer] = None) -> BufferRegion:
"""Construct BufferRegion from BufferSlice
Parameters
----------
analyzer : Optional[tvm.arith.Analyzer]
The analyzer for simplifying. If not provided, the method will construct a new one
Returns
-------
buffer_region : BufferRegion
The constructed BufferRegion.
"""
region: List[Range] = []
for s in self.slices:
start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start)
extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start
if not analyzer:
analyzer = Analyzer()
if isinstance(extent, PrimExpr):
extent = analyzer.simplify(extent)
if s.step != 1:
self.report_error("BufferRegion do not support non-trivial stride", s.span)
region.append(Range.from_min_extent(start, extent, span=s.span))
return BufferRegion(self.buffer, region)

def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
return self.asobject().astype(dtype, span)

Expand Down
7 changes: 2 additions & 5 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind

from .node import BufferSlice
from .utils import buffer_slice_to_region

from ..context_maintainer import ContextMaintainer
from ..registry import register
Expand Down Expand Up @@ -327,12 +326,10 @@ def block(name_hint: str = "", span: Optional[Span] = None):

# create block read/write regions
reads: List[BufferRegion] = (
[buffer_slice_to_region(read) for read in block_info.reads]
if block_info.reads
else []
[read.as_buffer_region() for read in block_info.reads] if block_info.reads else []
)
writes: List[BufferRegion] = (
[buffer_slice_to_region(write) for write in block_info.writes]
[write.as_buffer_region() for write in block_info.writes]
if block_info.writes
else []
)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from tvm.tir import IntImm, IterVar, Var

from .node import BufferSlice
from .utils import buffer_slice_to_region

from ..context_maintainer import BlockInfo, ContextMaintainer
from ..registry import register
Expand Down Expand Up @@ -168,7 +167,7 @@ def match_buffer(
)
self.context.func_buffer_map[param] = buffer
elif isinstance(param, BufferSlice):
buffer_region = buffer_slice_to_region(param)
buffer_region = param.as_buffer_region()
self.context.current_block_scope().match_buffers.append(
tvm.tir.MatchBufferRegion(buffer, buffer_region)
)
Expand Down
55 changes: 0 additions & 55 deletions python/tvm/script/tir/utils.py

This file was deleted.

29 changes: 27 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintRange(const RangeNode* op);
Doc PrintArray(const ArrayNode* op);
Doc PrintBuffer(const BufferNode* op);
Doc PrintBufferIndices(const Array<PrimExpr>& indices);
Doc PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers);
Doc AllocBufferDeclaration(const Buffer& buf);
Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
Expand Down Expand Up @@ -834,7 +835,7 @@ Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_p
if (op->indices.size() == 0) {
doc << Print(op->buffer) << "[()]";
} else {
doc << Print(op->buffer) << Print(op->indices);
doc << Print(op->buffer) << PrintBufferIndices(op->indices);
}
return doc;
}
Expand Down Expand Up @@ -1260,7 +1261,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) {
if (op->indices.size() == 0) {
doc << Print(op->buffer) << "[()] = " << Print(op->value);
} else {
doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value);
doc << Print(op->buffer) << PrintBufferIndices(op->indices) << " = " << Print(op->value);
}
return doc;
}
Expand Down Expand Up @@ -1678,6 +1679,30 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
}

Doc TVMScriptPrinter::PrintBufferIndices(const Array<PrimExpr>& indices) {
Doc doc;
doc << '[';
for (size_t i = 0; i < indices.size(); ++i) {
if (i != 0) {
doc << ", ";
}
PrimExpr index = indices[i];
if (const RampNode* ramp = index.as<RampNode>()) {
// specify ramp printing as python index slice
if (auto* stride_imm = ramp->stride.as<IntImmNode>()) {
doc << Print(ramp->base) << ":" << Print(ramp->base + ramp->lanes * ramp->stride);
if (stride_imm->value != 1) {
doc << ":" << Print(ramp->stride);
}
continue;
}
}
doc << Print(index);
}
doc << ']';
return doc;
}

Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers) {
Doc decls;
for (const auto& buf_usage : aliasing_buffers) {
Expand Down
51 changes: 28 additions & 23 deletions tests/python/unittest/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,29 +372,6 @@ def test_error_index_type():
check_error(error_bufferslice_index_type, 8)


def error_index_with_stop() -> None:
A = T.alloc_buffer((128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
A[vi, vj] = A[vi, 1:10] + 1 # error


def error_bufferslice_index_with_stop() -> None:
A = T.alloc_buffer((1,), "int32")
B = T.alloc_buffer((16, 16), "float32")
C = T.alloc_buffer((16, 16), "float32")
for i, j in T.grid(16, 16):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, A[0:1]] # error


def test_error_index_with_stop_slice():
check_error(error_index_with_stop, 6)
check_error(error_bufferslice_index_with_stop, 8)


def special_stmt_except() -> None:
A = T.alloc_buffer("(128, 128)", "float32") # error
T.evaluate(1.0)
Expand Down Expand Up @@ -658,5 +635,33 @@ def test_preflattened_buffer_map_offset_factor():
check_error(preflattened_buffer_map_offset_factor_nonint, 3)


def strided_buffer_region(A: T.handle):
# do not allow stride in buffer region
A = T.match_buffer((128, 128), "int32")
with T.block():
T.reads([])
T.writes([A[0:128:2, 0:128:3]]) # error
T.evaluate(T.call_extern("strided_compute", dtype=""))


def access_reversed_slice(A: T.handle):
# do not allow reversed slice step
A = T.match_buffer((128,), "int32")
A[0:128:-1] = T.broadcast(1, 128) # error


def access_non_const_slice_length(A: T.handle):
# do not allow non-constant slice length
A = T.match_buffer((128,), "int32")
for i in range(4):
T.evaluate(A[0:i:1]) # error


def test_illegal_buffer_slice():
check_error(strided_buffer_region, 3)
check_error(access_reversed_slice, 3)
check_error(access_non_const_slice_length, 3)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
17 changes: 17 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3272,6 +3272,22 @@ def element_wise(a: T.handle, c: T.handle) -> None:
return element_wise


def buffer_ramp_access_as_slice_index():
@T.prim_func
def buffer_ramp_access(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128,), "float32")
B = T.match_buffer(b, (128,), "float32")
C = T.match_buffer(c, (128,), "float32")
for i in range(128):
A[i : i + 1 : 1] = i
for i in range(4):
B[i * 32 : i * 32 + 32] = A[i * 32 : i * 32 + 32 : 1] + T.broadcast(1.0, 32)
for i in range(4):
C[i : i + 128 : 4] = B[i : i + 128 : 4] + T.broadcast(1.0, 32)

return buffer_ramp_access


ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
Expand Down Expand Up @@ -3308,6 +3324,7 @@ def element_wise(a: T.handle, c: T.handle) -> None:
string_annotation_escaping,
pointer_type,
buffer_axis_separator,
buffer_ramp_access_as_slice_index,
)


Expand Down

0 comments on commit 4893509

Please sign in to comment.