Skip to content

Commit

Permalink
[TensorIR] Support for match_buffer from subregion (#8585)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
6 people authored Jul 31, 2021
1 parent 5012462 commit 2a8950b
Show file tree
Hide file tree
Showing 29 changed files with 1,670 additions and 161 deletions.
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,12 @@ TVM_DLL Pass CompactBufferAllocation();
*/
TVM_DLL Pass LegalizePackedCalls();

/*!
 * \brief Remove the match buffers declared inside blocks, validating each
 *        binding along the way.
 * \return The pass.
 */
TVM_DLL Pass LowerMatchBuffer();

/*!
* \brief Flatten the multi-dimensional BufferLoad and BufferStore
* to single dimensional Load/Store. Also remove Block to
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,21 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
# match_buffers of the block,
# which bind a sub-region of source buffer into a new buffer
D = tir.match_buffer_region(C[vi, vj])
D = tir.match_buffer(C[vi, vj], ())
# init part of the block, executed when all reduce axes are the beginning value
with tir.init():
C[vi, vj] = tir.float32(0)
# block body
CC[0, 0] = A[vi, vk] * B[vj, vk]
D[0, 0] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
"""

alloc_buffers: List[Buffer] = []
"""List[Buffer]: list of tir.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature"""
"""List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature"""
iter_bindings: Mapping[Var, PrimExpr] = {}
"""Mapping[Var, PrimExpr]: map of block iter var to its values"""
reads: Optional[List[BufferSlice]] = None
Expand Down
22 changes: 21 additions & 1 deletion python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,10 +784,11 @@ def transform_Slice(self, node):
def transform_Subscript(self, node):
"""Array access visitor.
By now only 2 types of Subscript are supported:
By now only 3 types of Subscript are supported:
1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore)
Var[index] Buffer element access()
2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...]))
3. Array[index], Buffer element access
"""

symbol = self.transform(node.params[0])
Expand All @@ -812,6 +813,25 @@ def transform_Subscript(self, node):
return BufferSlice(
symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span)
)
elif isinstance(symbol, tvm.container.Array):
if len(indexes) > 1:
self.report_error(
"Array access should be one-dimension access, but the indices are "
+ str(indexes),
node.span,
)
index = indexes[0]
if not isinstance(index, (int, tvm.tir.expr.IntImm)):
self.report_error(
"Array access index expected int or IntImm, but got " + type(index),
node.span,
)
if int(index) >= len(symbol):
self.report_error(
f"Array access out of bound, size: {len(symbol)}, got index {index}.",
node.span,
)
return symbol[int(index)]
else:
self.report_error(
f"Cannot subscript from a {type(symbol).__name__}. Only variables and "
Expand Down
98 changes: 30 additions & 68 deletions python/tvm/script/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,24 @@ def handle(

@register
class MatchBuffer(SpecialStmt):
"""Special Stmt match_buffer(var, shape, dtype, data, strides, elem_offset, scope, align,
"""Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align,
offset_factor, buffer_type)
Note
----
This Special Stmt will perform different behavior depends on the type of param.
If the param is a var in function parameter, it will create a buffer from DLTensor.
Else if the param is a subregion of other buffers, then create a subregion match inside a block.
Example
-------
Match buffer from function parameter
.. code-block:: python
A = tir.match_buffer(a, (128, 128), dtype="float32")
Match buffer from Buffer subregion
.. code-block:: python
A = tir.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32")
"""

def __init__(self):
Expand All @@ -123,10 +135,6 @@ def match_buffer(
"match_buffer must be assigned to a buffer, e.g. A = match_buffer(...)",
self.node.span,
)
if param not in self.context.func_params:
self.context.report_error(
"Can not bind non-input param to buffer", self.node.rhs.params[0].span
)
if strides is None:
strides = []
align = convert_to_int(align, "align", self.context.report_error, self.node.span)
Expand All @@ -146,7 +154,23 @@ def match_buffer(
buffer_type,
span=span,
)
self.context.func_buffer_map[param] = buffer
if isinstance(param, tvm.tir.Var):
if param not in self.context.func_params:
self.context.report_error(
"Can not bind non-input param to buffer", self.node.rhs.params[0].span
)
self.context.func_buffer_map[param] = buffer
elif isinstance(param, BufferSlice):
buffer_region = buffer_slice_to_region(param)
self.context.current_block_scope().match_buffers.append(
tvm.tir.MatchBufferRegion(buffer, buffer_region)
)
else:
self.context.report_error(
"The source of match_buffer expected Var or BufferSlice, but got "
+ str(type(param)),
self.node.rhs.params[0].span,
)
self.context.update_symbol(self.node.lhs.id.name, buffer, self.node)

super().__init__(match_buffer, def_symbol=True)
Expand Down Expand Up @@ -414,68 +438,6 @@ def where(predicate, span=None):
super().__init__(where, def_symbol=False)


@register
class BlockMatchBufferRegion(SpecialStmt):
"""Special function match_buffer_region(source, strides, elem_offset, align, offset_factor)
Binds a sub-region of an existing buffer, given as a BufferSlice, to a newly declared
buffer and appends the binding to the match_buffers of the current block scope.
The new buffer takes its shape from the extents of the source region, its dtype and
scope from the source buffer, and its name from the assignment target.
Example
-------
.. code-block:: python
B = tir.match_buffer_region(A[0: 4])
"""

def __init__(self):
def match_buffer_region(
source,
strides=None,
elem_offset=None,
align=-1,
offset_factor=0,
span=None,
):
assert self.context, "call 'exit_scope' before 'enter_scope'"
# The statement must have the form `X = match_buffer_region(...)`,
# because the left-hand side supplies the new buffer's name.
if not isinstance(self.node, ast.Assign):
self.context.report_error(
"match_buffer_region must be assigned to a buffer, "
+ "e.g. A = match_buffer_region(...)",
self.node.span,
)

if strides is None:
strides = []
# Pass align and offset_factor through convert_to_int, attributing any
# error to this statement's span.
align = convert_to_int(align, "align", self.context.report_error, self.node.span)
offset_factor = convert_to_int(
offset_factor, "offset_factor", self.context.report_error, self.node.span
)

if not isinstance(source, BufferSlice):
self.context.report_error(
"match_buffer_region needs a buffer region as source",
span=span,
)
buffer_region = buffer_slice_to_region(source)
# One extent per dimension of the matched region becomes the new buffer's shape.
shape = [r.extent for r in buffer_region.region]
buffer = tvm.tir.decl_buffer(
shape,
buffer_region.buffer.dtype,
self.node.lhs.id.name,
data=None,
strides=strides,
elem_offset=elem_offset,
scope=buffer_region.buffer.scope(),
data_alignment=align,
offset_factor=offset_factor,
span=span,
)
# Record the (new buffer, source region) pair on the enclosing block scope.
self.context.current_block_scope().match_buffers.append(
tvm.tir.MatchBufferRegion(buffer, buffer_region)
)
# Register the new buffer in the context under the assignment target's name.
self.context.update_symbol(self.node.lhs.id.name, buffer, self.node)

super().__init__(match_buffer_region, def_symbol=True)


@register
class VarDef(SpecialStmt):
"""Special function for defining a Var"""
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def decl_buffer(
dtype = "float32" if dtype is None else dtype
strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None:
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
shape_dtype = shape[0].dtype if shape and hasattr(shape[0], "dtype") else "int32"
elem_offset = Var("%s_elem_offset" % name, shape_dtype)
if data is None:
# Bool is represented as uint1 in the IR, but stored as int8
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,17 @@ def CompactBufferAllocation():
return _ffi_api.CompactBufferAllocation() # type: ignore


def LowerMatchBuffer():
    """Create the pass that removes match buffers inside blocks.

    The pass also validates the binding of each match buffer.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LowerMatchBuffer()  # type: ignore
    return fpass


def FlattenBuffer():
"""Flatten the multi-dimensional BufferLoad and BufferStore
to single dimensional Load/Store. Also remove Block to
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::FlattenBuffer());
}
pass_list.push_back(tir::transform::BF16Legalize());
Expand Down
4 changes: 2 additions & 2 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,8 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
<< Print(alloc_buf->shape) << ")" << Doc::NewLine();
}
for (const auto& match_buf : block_op->match_buffers) {
body << AllocBuf(match_buf->buffer) << " = match_buffer_region(" << Print(match_buf->source)
<< ")" << Doc::NewLine();
body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")"
<< Doc::NewLine();
}
if (block_op->init.defined()) {
Doc init_block;
Expand Down
25 changes: 2 additions & 23 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,29 +337,8 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
const Buffer& buf = op->buffer;
buf_not_in_headers.insert(buf.get());

Doc doc = Print(op->buffer) << " = tir.match_buffer_region(" << Print(op->source);
if (!buf->strides.empty()) {
doc << ", strides=" << Print(buf->strides);
}
if (buf->offset_factor != 0 && buf->elem_offset->IsInstance<VarNode>()) {
Var elem_offset = Downcast<Var>(buf->elem_offset);
if (memo_var_.find(elem_offset) != memo_var_.end()) {
doc << ", elem_offset=" << Print(buf->elem_offset);
} else {
// implicitly define elem_offset
memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset");
var_not_in_headers.insert(elem_offset.get());
}
} else {
doc << ", elem_offset=" << Print(buf->elem_offset);
}
if (buf->data_alignment != -1) {
doc << ", align=" << buf->data_alignment;
}
if (buf->offset_factor != 0) {
doc << ", offset_factor=" << buf->offset_factor;
}
doc << ")";
Doc doc = Print(op->buffer) << " = tir.match_buffer(" << Print(op->source) << ", "
<< memo_buf_decl_[op->buffer] << ")";
return doc;
}

Expand Down
Loading

0 comments on commit 2a8950b

Please sign in to comment.