Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Represent ramp as index slice #11308

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,14 @@ def transform_SubscriptAssign(self, node):
f"cannot be indexed by {len(indexes)}-dimensional indices.",
node.params[1].span,
)

def __convert_index(x):
if isinstance(x, Slice):
return x.as_index_expr(self.report_error)
return x

# BufferStore
indexes = [__convert_index(x) for x in indexes]
return tvm.tir.BufferStore(
symbol,
tvm.runtime.convert(rhs, span=rhs_span),
Expand Down Expand Up @@ -948,11 +955,18 @@ def f():
)

def transform_Slice(self, node):
"""Index slice visitor."""
start = self.transform(node.start)
end = self.transform(node.end)
if not (isinstance(node.step, ast.Constant) and node.step.value == 1):
self.report_error("Only step size 1 is supported for slices.", node.step.span)
return Slice(start, end)
if not (
isinstance(node.step, ast.Constant)
and isinstance(node.step.value, int)
and node.step.value > 0
):
self.report_error(
"Only positive integer step size is supported for slices.", node.step.span
)
return Slice(start, end, node.step.value, tvm_span_from_synr(node.span))

def transform_Subscript(self, node):
"""Array access visitor.
Expand Down
64 changes: 56 additions & 8 deletions python/tvm/script/tir/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

from typing import Optional, Union, List, Callable
import synr

from tvm.arith import Analyzer
from tvm.runtime import ObjectGeneric, convert
from tvm.tir import PrimExpr, Buffer, BufferLoad
from tvm.ir import Span
from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
from tvm.ir import Span, Range


class Slice:
Expand All @@ -36,24 +36,50 @@ class Slice:
stop : Optional[Union[PrimExpr, int]]
The stop index, None means the Slice is an element-wise index

step : int
The slice step

span : Optional[Span]
The location of the slice in the source.
"""

start: Union[PrimExpr, int]
stop: Optional[Union[PrimExpr, int]]
step: int
span: Optional[Span]

def __init__(
self,
start: Union[PrimExpr, int],
stop: Optional[Union[PrimExpr, int]] = None,
step: int = 1,
span: Optional[Span] = None,
):
self.start = start
self.stop = stop
self.step = step
self.span = span

def as_index_expr(self, report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
"""Helper to create index PrimExpr from slice object
Parameters
----------
report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
The error report func
"""
if self.stop is None:
# scalar index
return self.start
extent = Analyzer().simplify(self.stop - self.start)
if not isinstance(extent, (int, IntImm)):
report_error("Slice's extent should be constant for buffer indices", self.span)
if self.step < 1:
report_error("Slice's step should be positive integer", self.span)
lanes = (int(extent) + self.step - 1) // self.step
if lanes == 1:
return self.start
return Ramp(self.start, self.step, lanes, self.span)


class BufferSlice(ObjectGeneric):
"""A generic object for representing general buffer access. Following cases are supported:
Expand Down Expand Up @@ -148,13 +174,35 @@ def __str__(self):

def asobject(self) -> BufferLoad:
"""Convert object."""
for s in self.slices:
if s.stop is not None:
self.report_error("BufferLoad only accepts elementwise access", self.span)

indices = [s.start for s in self.slices]
indices = [s.as_index_expr(self.report_error) for s in self.slices]
return BufferLoad(self.buffer, indices, span=self.span)

def as_buffer_region(self, analyzer: Optional[Analyzer] = None) -> BufferRegion:
"""Construct BufferRegion from BufferSlice

Parameters
----------
analyzer : Optional[tvm.arith.Analyzer]
The analyzer for simplifying. If not provided, the method will construct a new one

Returns
-------
buffer_region : BufferRegion
The constructed BufferRegion.
"""
region: List[Range] = []
for s in self.slices:
start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start)
extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start
if not analyzer:
analyzer = Analyzer()
if isinstance(extent, PrimExpr):
extent = analyzer.simplify(extent)
if s.step != 1:
self.report_error("BufferRegion do not support non-trivial stride", s.span)
region.append(Range.from_min_extent(start, extent, span=s.span))
return BufferRegion(self.buffer, region)

def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
return self.asobject().astype(dtype, span)

Expand Down
7 changes: 2 additions & 5 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind

from .node import BufferSlice
from .utils import buffer_slice_to_region

from ..context_maintainer import ContextMaintainer
from ..registry import register
Expand Down Expand Up @@ -327,12 +326,10 @@ def block(name_hint: str = "", span: Optional[Span] = None):

# create block read/write regions
reads: List[BufferRegion] = (
[buffer_slice_to_region(read) for read in block_info.reads]
if block_info.reads
else []
[read.as_buffer_region() for read in block_info.reads] if block_info.reads else []
)
writes: List[BufferRegion] = (
[buffer_slice_to_region(write) for write in block_info.writes]
[write.as_buffer_region() for write in block_info.writes]
if block_info.writes
else []
)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from tvm.tir import IntImm, IterVar, Var

from .node import BufferSlice
from .utils import buffer_slice_to_region

from ..context_maintainer import BlockInfo, ContextMaintainer
from ..registry import register
Expand Down Expand Up @@ -168,7 +167,7 @@ def match_buffer(
)
self.context.func_buffer_map[param] = buffer
elif isinstance(param, BufferSlice):
buffer_region = buffer_slice_to_region(param)
buffer_region = param.as_buffer_region()
self.context.current_block_scope().match_buffers.append(
tvm.tir.MatchBufferRegion(buffer, buffer_region)
)
Expand Down
55 changes: 0 additions & 55 deletions python/tvm/script/tir/utils.py

This file was deleted.

29 changes: 27 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintRange(const RangeNode* op);
Doc PrintArray(const ArrayNode* op);
Doc PrintBuffer(const BufferNode* op);
Doc PrintBufferIndices(const Array<PrimExpr>& indices);
Doc PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers);
Doc AllocBufferDeclaration(const Buffer& buf);
Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
Expand Down Expand Up @@ -834,7 +835,7 @@ Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_p
if (op->indices.size() == 0) {
doc << Print(op->buffer) << "[()]";
} else {
doc << Print(op->buffer) << Print(op->indices);
doc << Print(op->buffer) << PrintBufferIndices(op->indices);
}
return doc;
}
Expand Down Expand Up @@ -1260,7 +1261,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) {
if (op->indices.size() == 0) {
doc << Print(op->buffer) << "[()] = " << Print(op->value);
} else {
doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value);
doc << Print(op->buffer) << PrintBufferIndices(op->indices) << " = " << Print(op->value);
}
return doc;
}
Expand Down Expand Up @@ -1678,6 +1679,30 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
}

Doc TVMScriptPrinter::PrintBufferIndices(const Array<PrimExpr>& indices) {
Doc doc;
doc << '[';
for (size_t i = 0; i < indices.size(); ++i) {
if (i != 0) {
doc << ", ";
}
PrimExpr index = indices[i];
if (const RampNode* ramp = index.as<RampNode>()) {
// specify ramp printing as python index slice
if (auto* stride_imm = ramp->stride.as<IntImmNode>()) {
doc << Print(ramp->base) << ":" << Print(ramp->base + ramp->lanes * ramp->stride);
if (stride_imm->value != 1) {
doc << ":" << Print(ramp->stride);
}
continue;
}
}
doc << Print(index);
}
doc << ']';
return doc;
}

Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers) {
Doc decls;
for (const auto& buf_usage : aliasing_buffers) {
Expand Down
60 changes: 37 additions & 23 deletions tests/python/unittest/test_tvmscript_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,29 +372,6 @@ def test_error_index_type():
check_error(error_bufferslice_index_type, 8)


def error_index_with_stop() -> None:
A = T.alloc_buffer((128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
A[vi, vj] = A[vi, 1:10] + 1 # error


def error_bufferslice_index_with_stop() -> None:
A = T.alloc_buffer((1,), "int32")
B = T.alloc_buffer((16, 16), "float32")
C = T.alloc_buffer((16, 16), "float32")
for i, j in T.grid(16, 16):
with T.block():
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, A[0:1]] # error


def test_error_index_with_stop_slice():
check_error(error_index_with_stop, 6)
check_error(error_bufferslice_index_with_stop, 8)


def special_stmt_except() -> None:
A = T.alloc_buffer("(128, 128)", "float32") # error
T.evaluate(1.0)
Expand Down Expand Up @@ -658,5 +635,42 @@ def test_preflattened_buffer_map_offset_factor():
check_error(preflattened_buffer_map_offset_factor_nonint, 3)


def strided_buffer_region(A: T.handle):
# do not allow stride in buffer region
A = T.match_buffer((128, 128), "int32")
with T.block():
T.reads([])
T.writes([A[0:128:2, 0:128:3]]) # error
T.evaluate(T.call_extern("strided_compute", dtype=""))


def strided_buffer_region(A: T.handle):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the duplicated test

# do not allow stride in buffer region
A = T.match_buffer((128, 128), "int32")
with T.block():
T.reads([])
T.writes([A[0:128:2, 0:128:3]]) # error
T.evaluate(T.call_extern("strided_compute", dtype=""))


def access_reversed_slice(A: T.handle):
# do not allow reversed slice step
A = T.match_buffer((128,), "int32")
A[0:128:-1] = T.broadcast(1, 128)


def access_non_const_slice(A: T.handle):
# do not allow reversed slice step
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error comment?

A = T.match_buffer((128,), "int32")
for i in range(4):
T.evaluate(A[0:i:1])


def test_illegal_buffer_slice():
check_error(strided_buffer_region, 3)
check_error(access_reversed_slice, 3)
check_error(access_non_const_slice, 3)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
17 changes: 17 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3272,6 +3272,22 @@ def element_wise(a: T.handle, c: T.handle) -> None:
return element_wise


def buffer_ramp_access_as_slice_index():
@T.prim_func
def buffer_ramp_access(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128,), "float32")
B = T.match_buffer(b, (128,), "float32")
C = T.match_buffer(c, (128,), "float32")
for i in range(128):
A[i : i + 1 : 1] = i
for i in range(4):
B[i * 32 : i * 32 + 32] = A[i * 32 : i * 32 + 32 : 1] + T.broadcast(1.0, 32)
for i in range(4):
C[i : i + 128 : 4] = B[i : i + 128 : 4] + T.broadcast(1.0, 32)

return buffer_ramp_access


ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
Expand Down Expand Up @@ -3308,6 +3324,7 @@ def element_wise(a: T.handle, c: T.handle) -> None:
string_annotation_escaping,
pointer_type,
buffer_axis_separator,
buffer_ramp_access_as_slice_index,
)


Expand Down