From 3cc3dc82534d611a05b304e5a8124086135eca5b Mon Sep 17 00:00:00 2001 From: wrongtest Date: Fri, 13 May 2022 19:08:31 +0800 Subject: [PATCH 1/2] support represent ramp as index slice in tvmscript --- python/tvm/script/parser.py | 20 +++++- python/tvm/script/tir/node.py | 64 ++++++++++++++++--- python/tvm/script/tir/scope_handler.py | 7 +- python/tvm/script/tir/special_stmt.py | 3 +- python/tvm/script/tir/utils.py | 55 ---------------- src/printer/tvmscript_printer.cc | 29 ++++++++- .../unittest/test_tvmscript_error_report.py | 60 ++++++++++------- .../unittest/test_tvmscript_roundtrip.py | 17 +++++ 8 files changed, 157 insertions(+), 98 deletions(-) delete mode 100644 python/tvm/script/tir/utils.py diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index c26812db4062..fe71b064320f 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -631,7 +631,14 @@ def transform_SubscriptAssign(self, node): f"cannot be indexed by {len(indexes)}-dimensional indices.", node.params[1].span, ) + + def __convert_index(x): + if isinstance(x, Slice): + return x.as_index_expr(self.report_error) + return x + # BufferStore + indexes = [__convert_index(x) for x in indexes] return tvm.tir.BufferStore( symbol, tvm.runtime.convert(rhs, span=rhs_span), @@ -948,11 +955,18 @@ def f(): ) def transform_Slice(self, node): + """Index slice visitor.""" start = self.transform(node.start) end = self.transform(node.end) - if not (isinstance(node.step, ast.Constant) and node.step.value == 1): - self.report_error("Only step size 1 is supported for slices.", node.step.span) - return Slice(start, end) + if not ( + isinstance(node.step, ast.Constant) + and isinstance(node.step.value, int) + and node.step.value > 0 + ): + self.report_error( + "Only positive integer step size is supported for slices.", node.step.span + ) + return Slice(start, end, node.step.value, tvm_span_from_synr(node.span)) def transform_Subscript(self, node): """Array access visitor. diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py index 49b1b3a99d95..3779df6c3b05 100644 --- a/python/tvm/script/tir/node.py +++ b/python/tvm/script/tir/node.py @@ -19,10 +19,10 @@ from typing import Optional, Union, List, Callable import synr - +from tvm.arith import Analyzer from tvm.runtime import ObjectGeneric, convert -from tvm.tir import PrimExpr, Buffer, BufferLoad -from tvm.ir import Span +from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion +from tvm.ir import Span, Range class Slice: @@ -36,24 +36,50 @@ class Slice: stop : Optional[Union[PrimExpr, int]] The stop index, None means the Slice is an element-wise index + step : int + The slice step + span : Optional[Span] The location of the slice in the source. """ start: Union[PrimExpr, int] stop: Optional[Union[PrimExpr, int]] + step: int span: Optional[Span] def __init__( self, start: Union[PrimExpr, int], stop: Optional[Union[PrimExpr, int]] = None, + step: int = 1, span: Optional[Span] = None, ): self.start = start self.stop = stop + self.step = step self.span = span + def as_index_expr(self, report_error: Callable[[str, Union[Span, synr.ast.Span]], None]): + """Helper to create index PrimExpr from slice object + Parameters + ---------- + report_error: Callable[[str, Union[Span, synr.ast.Span]], None] + The error report func + """ + if self.stop is None: + # scalar index + return self.start + extent = Analyzer().simplify(self.stop - self.start) + if not isinstance(extent, (int, IntImm)): + report_error("Slice's extent should be constant for buffer indices", self.span) + if self.step < 1: + report_error("Slice's step should be positive integer", self.span) + lanes = (int(extent) + self.step - 1) // self.step + if lanes == 1: + return self.start + return Ramp(self.start, self.step, lanes, self.span) + class BufferSlice(ObjectGeneric): """A generic object for representing general buffer access. Following cases are supported: @@ -148,13 +174,35 @@ def __str__(self): def asobject(self) -> BufferLoad: """Convert object.""" - for s in self.slices: - if s.stop is not None: - self.report_error("BufferLoad only accepts elementwise access", self.span) - - indices = [s.start for s in self.slices] + indices = [s.as_index_expr(self.report_error) for s in self.slices] return BufferLoad(self.buffer, indices, span=self.span) + def as_buffer_region(self, analyzer: Optional[Analyzer] = None) -> BufferRegion: + """Construct BufferRegion from BufferSlice + + Parameters + ---------- + analyzer : Optional[tvm.arith.Analyzer] + The analyzer for simplifying. If not provided, the method will construct a new one + + Returns + ------- + buffer_region : BufferRegion + The constructed BufferRegion. + """ + region: List[Range] = [] + for s in self.slices: + start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start) + extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start + if not analyzer: + analyzer = Analyzer() + if isinstance(extent, PrimExpr): + extent = analyzer.simplify(extent) + if s.step != 1: + self.report_error("BufferRegion do not support non-trivial stride", s.span) + region.append(Range.from_min_extent(start, extent, span=s.span)) + return BufferRegion(self.buffer, region) + def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: return self.asobject().astype(dtype, span) diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 2e1d5b605913..7d3250fe8711 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -26,7 +26,6 @@ from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind from .node import BufferSlice -from .utils import buffer_slice_to_region from ..context_maintainer import ContextMaintainer from ..registry import register @@ -327,12 +326,10 @@ def block(name_hint: str = "", span: Optional[Span] = None): # create block read/write regions reads: List[BufferRegion] = ( - [buffer_slice_to_region(read) for read in block_info.reads] - if block_info.reads - else [] + [read.as_buffer_region() for read in block_info.reads] if block_info.reads else [] ) writes: List[BufferRegion] = ( - [buffer_slice_to_region(write) for write in block_info.writes] + [write.as_buffer_region() for write in block_info.writes] if block_info.writes else [] ) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 39a345de7f1e..15502055b7fc 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -30,7 +30,6 @@ from tvm.tir import IntImm, IterVar, Var from .node import BufferSlice -from .utils import buffer_slice_to_region from ..context_maintainer import BlockInfo, ContextMaintainer from ..registry import register @@ -168,7 +167,7 @@ def match_buffer( ) self.context.func_buffer_map[param] = buffer elif isinstance(param, BufferSlice): - buffer_region = buffer_slice_to_region(param) + buffer_region = param.as_buffer_region() self.context.current_block_scope().match_buffers.append( tvm.tir.MatchBufferRegion(buffer, buffer_region) ) diff --git a/python/tvm/script/tir/utils.py b/python/tvm/script/tir/utils.py deleted file mode 100644 index e106dab636a1..000000000000 --- a/python/tvm/script/tir/utils.py +++ /dev/null @@ -1,55 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Helper functions in TVM Script Parser""" - -from typing import List, Optional - -from tvm.arith import Analyzer -from tvm.ir import Range -from tvm.tir import PrimExpr, BufferRegion -from tvm.tir.expr import IntImm -from .node import BufferSlice - - -def buffer_slice_to_region( - buffer_slice: BufferSlice, analyzer: Optional[Analyzer] = None -) -> BufferRegion: - """Construct BufferRegion from BufferSlice - - Parameters - ---------- - buffer_slice : BufferSlice - The input BufferSlice - - analyzer : Optional[tvm.arith.Analyzer] - The analyzer for simplifying. If not provided, the method will construct a new one - - Returns - ------- - buffer_region : BufferRegion - The constructed BufferRegion. - """ - region: List[Range] = [] - for s in buffer_slice.slices: - start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start) - extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start - if not analyzer: - analyzer = Analyzer() - if isinstance(extent, PrimExpr): - extent = analyzer.simplify(extent) - region.append(Range.from_min_extent(start, extent, span=s.span)) - return BufferRegion(buffer_slice.buffer, region) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 6f8d10b32040..99d1a7845d3f 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -265,6 +265,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintRange(const RangeNode* op); Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); + Doc PrintBufferIndices(const Array& indices); Doc PrintNonHeaderBufferDeclarations(const Array& aliasing_buffers); Doc AllocBufferDeclaration(const Buffer& buf); Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); @@ -834,7 +835,7 @@ Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_p if (op->indices.size() == 0) { doc << Print(op->buffer) << "[()]"; } else { - doc << Print(op->buffer) << Print(op->indices); + doc << Print(op->buffer) << PrintBufferIndices(op->indices); } return doc; } @@ -1260,7 +1261,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { if (op->indices.size() == 0) { doc << Print(op->buffer) << "[()] = " << Print(op->value); } else { - doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); + doc << Print(op->buffer) << PrintBufferIndices(op->indices) << " = " << Print(op->value); } return doc; } @@ -1678,6 +1679,30 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) { return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer); } +Doc TVMScriptPrinter::PrintBufferIndices(const Array& indices) { + Doc doc; + doc << '['; + for (size_t i = 0; i < indices.size(); ++i) { + if (i != 0) { + doc << ", "; + } + PrimExpr index = indices[i]; + if (const RampNode* ramp = index.as()) { + // specify ramp printing as python index slice + if (auto* stride_imm = ramp->stride.as()) { + doc << Print(ramp->base) << ":" << Print(ramp->base + ramp->lanes * ramp->stride); + if (stride_imm->value != 1) { + doc << ":" << Print(ramp->stride); + } + continue; + } + } + doc << Print(index); + } + doc << ']'; + return doc; +} + Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(const Array& aliasing_buffers) { Doc decls; for (const auto& buf_usage : aliasing_buffers) { diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 0610559a05d8..9d5d54594d45 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -372,29 +372,6 @@ def test_error_index_type(): check_error(error_bufferslice_index_type, 8) -def error_index_with_stop() -> None: - A = T.alloc_buffer((128, 128), "float32") - for i, j in T.grid(128, 128): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - A[vi, vj] = A[vi, 1:10] + 1 # error - - -def error_bufferslice_index_with_stop() -> None: - A = T.alloc_buffer((1,), "int32") - B = T.alloc_buffer((16, 16), "float32") - C = T.alloc_buffer((16, 16), "float32") - for i, j in T.grid(16, 16): - with T.block(): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi, A[0:1]] # error - - -def test_error_index_with_stop_slice(): - check_error(error_index_with_stop, 6) - check_error(error_bufferslice_index_with_stop, 8) - - def special_stmt_except() -> None: A = T.alloc_buffer("(128, 128)", "float32") # error T.evaluate(1.0) @@ -658,5 +635,42 @@ def test_preflattened_buffer_map_offset_factor(): check_error(preflattened_buffer_map_offset_factor_nonint, 3) +def strided_buffer_region(A: T.handle): + # do not allow stride in buffer region + A = T.match_buffer((128, 128), "int32") + with T.block(): + T.reads([]) + T.writes([A[0:128:2, 0:128:3]]) # error + T.evaluate(T.call_extern("strided_compute", dtype="")) + + +def strided_buffer_region(A: T.handle): + # do not allow stride in buffer region + A = T.match_buffer((128, 128), "int32") + with T.block(): + T.reads([]) + T.writes([A[0:128:2, 0:128:3]]) # error + T.evaluate(T.call_extern("strided_compute", dtype="")) + + +def access_reversed_slice(A: T.handle): + # do not allow reversed slice step + A = T.match_buffer((128,), "int32") + A[0:128:-1] = T.broadcast(1, 128) + + +def access_non_const_slice(A: T.handle): + # do not allow reversed slice step + A = T.match_buffer((128,), "int32") + for i in range(4): + T.evaluate(A[0:i:1]) + + +def test_illegal_buffer_slice(): + check_error(strided_buffer_region, 3) + check_error(access_reversed_slice, 3) + check_error(access_non_const_slice, 3) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index c704baebc7e1..948a76216831 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3272,6 +3272,22 @@ def element_wise(a: T.handle, c: T.handle) -> None: return element_wise +def buffer_ramp_access_as_slice_index(): + @T.prim_func + def buffer_ramp_access(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128,), "float32") + B = T.match_buffer(b, (128,), "float32") + C = T.match_buffer(c, (128,), "float32") + for i in range(128): + A[i : i + 1 : 1] = i + for i in range(4): + B[i * 32 : i * 32 + 32] = A[i * 32 : i * 32 + 32 : 1] + T.broadcast(1.0, 32) + for i in range(4): + C[i : i + 128 : 4] = B[i : i + 128 : 4] + T.broadcast(1.0, 32) + + return buffer_ramp_access + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3308,6 +3324,7 @@ def element_wise(a: T.handle, c: T.handle) -> None: string_annotation_escaping, pointer_type, buffer_axis_separator, + buffer_ramp_access_as_slice_index, ) From 3a30d1d071ecbc573a1515beed2cd731880447dc Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sat, 14 May 2022 00:48:19 +0800 Subject: [PATCH 2/2] fix testcase's comment, check slice lanes instead of extent --- python/tvm/script/tir/node.py | 9 ++++----- .../unittest/test_tvmscript_error_report.py | 19 +++++-------------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py index 3779df6c3b05..29e79607fbc9 100644 --- a/python/tvm/script/tir/node.py +++ b/python/tvm/script/tir/node.py @@ -70,15 +70,14 @@ def as_index_expr(self, report_error: Callable[[str, Union[Span, synr.ast.Span]] if self.stop is None: # scalar index return self.start - extent = Analyzer().simplify(self.stop - self.start) - if not isinstance(extent, (int, IntImm)): - report_error("Slice's extent should be constant for buffer indices", self.span) if self.step < 1: report_error("Slice's step should be positive integer", self.span) - lanes = (int(extent) + self.step - 1) // self.step + lanes = Analyzer().simplify((self.stop - self.start + self.step - 1) // self.step) + if not isinstance(lanes, (int, IntImm)): + report_error("Slice's lanes should be constant for buffer indices", self.span) if lanes == 1: return self.start - return Ramp(self.start, self.step, lanes, self.span) + return Ramp(self.start, self.step, int(lanes), self.span) class BufferSlice(ObjectGeneric): diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 9d5d54594d45..070b5e85f174 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -644,32 +644,23 @@ def strided_buffer_region(A: T.handle): T.evaluate(T.call_extern("strided_compute", dtype="")) -def strided_buffer_region(A: T.handle): - # do not allow stride in buffer region - A = T.match_buffer((128, 128), "int32") - with T.block(): - T.reads([]) - T.writes([A[0:128:2, 0:128:3]]) # error - T.evaluate(T.call_extern("strided_compute", dtype="")) - - def access_reversed_slice(A: T.handle): # do not allow reversed slice step A = T.match_buffer((128,), "int32") - A[0:128:-1] = T.broadcast(1, 128) + A[0:128:-1] = T.broadcast(1, 128) # error -def access_non_const_slice(A: T.handle): - # do not allow reversed slice step +def access_non_const_slice_length(A: T.handle): + # do not allow non-constant slice length A = T.match_buffer((128,), "int32") for i in range(4): - T.evaluate(A[0:i:1]) + T.evaluate(A[0:i:1]) # error def test_illegal_buffer_slice(): check_error(strided_buffer_region, 3) check_error(access_reversed_slice, 3) - check_error(access_non_const_slice, 3) + check_error(access_non_const_slice_length, 3) if __name__ == "__main__":