[TIR] tir.transform.StorageFlatten refactor (#9091)

* [TE] Improved flexibility of ArgBinder::BindDLTensor

Allowed a compact DLTensor to bind to a Buffer object that defines
strides, so long as the defined strides correspond to a compact layout.
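
As an illustration of the compactness condition (an editorial sketch,
not the TVM implementation): a strided buffer is compact when each
stride equals the product of all faster-varying extents.

    #include <cstdint>
    #include <vector>

    // Row-major compactness: strides[k] == shape[k+1] * ... * shape[n-1].
    bool IsCompactRowMajor(const std::vector<int64_t>& shape,
                           const std::vector<int64_t>& strides) {
      if (shape.size() != strides.size()) return false;
      int64_t expected = 1;
      for (size_t i = shape.size(); i != 0; --i) {
        size_t k = i - 1;
        if (strides[k] != expected) return false;
        expected *= shape[k];
      }
      return true;
    }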

* [TIR] Exposed ElemOffset as a member function of BufferNode.
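
A hypothetical usage sketch of the newly exposed member (the names
`FlatIndexOf`, `buf`, `i`, and `j` are illustrative, not part of the
commit):

    #include <tvm/tir/buffer.h>
    #include <tvm/tir/var.h>

    // Offset of element (i, j) in units of buf->dtype, before any
    // adjustment for the number of lanes.
    tvm::PrimExpr FlatIndexOf(const tvm::tir::Buffer& buf, tvm::tir::Var i, tvm::tir::Var j) {
      return buf->ElemOffset({i, j});
    }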

* [TE] Pulled shape determination out of StorageFlattener

Previously, StorageFlattener would determine the shape of a physical
buffer based on the extents of the BufferRealizeNode.  This shape
determination is now pulled out into a separate BufferShapeLegalize
pass.  After this pass, all buffers have a shape that matches the
buffer realization extents.
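
A generic sketch of the idea (not the TVM pass itself): the realized
region's extents become the buffer shape, and each access is rebased
by subtracting the region's minimum in the corresponding dimension.

    #include <cstdint>
    #include <utility>
    #include <vector>

    struct LegalizedAccess {
      std::vector<int64_t> shape;    // one extent per realized dimension
      std::vector<int64_t> indices;  // access indices rebased to start at 0
    };

    // realize_bounds holds (min, extent) per dimension, as in a BufferRealize region.
    LegalizedAccess LegalizeAccess(const std::vector<std::pair<int64_t, int64_t>>& realize_bounds,
                                   const std::vector<int64_t>& original_indices) {
      LegalizedAccess out;
      for (size_t k = 0; k < realize_bounds.size(); ++k) {
        out.shape.push_back(realize_bounds[k].second);
        out.indices.push_back(original_indices[k] - realize_bounds[k].first);
      }
      return out;
    }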

* [TE] Refactor stride calculation out of StorageFlattener

Previously, StorageFlattener would handle any attr::dim_align
annotations.  Now, this is pulled out into a separate
BufferStrideLegalize pass.
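
A hedged sketch of the dim_align-style stride padding (illustrative
only; the pass operates on PrimExprs rather than plain integers): the
stride of the annotated dimension is padded upward until it is
congruent to `offset` modulo `factor`.

    #include <cstdint>

    int64_t AlignStride(int64_t stride, int64_t factor, int64_t offset) {
      if (factor == 0) return stride;
      // Smallest value >= stride with value % factor == offset.
      return stride + (factor + offset - stride % factor) % factor;
    }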

* [TE] Refactor thread scope propagation out of StorageFlattener.

Previously, StorageFlattener would use the scope in IterVar to assign
a scope to allocated buffers, where not otherwise defined.  This has
been pulled out into a separate ThreadScopePropagate pass.

* [TE] Refactor buffer bind mapping out of StorageFlattener.

Previously, StorageFlattener would look for `attr::buffer_bind_scope`
to determine if a Buffer object is a view into another buffer, and
would apply that mapping while making the Allocate/Store/Load nodes.
Now, the mapping of buffer binds is pulled out into a separate
BufferBindUnwrapper pass.

This also resolves an issue in which BufferLoad/BufferStore nodes that
refer to a Buffer defined through `attr::buffer_bind_scope` would
generate Load/Store nodes that point to the linked buffer, rather than
the actual buffer.
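
Conceptually (an editorial sketch, not the pass implementation), a
buffer_bind_scope view is a window into a backing buffer starting at a
per-dimension begin offset, so an access to the view maps to the
backing buffer as follows:

    #include <cstdint>
    #include <vector>

    // view[i0, i1, ...]  ->  backing[begin0 + i0, begin1 + i1, ...]
    std::vector<int64_t> ViewToBackingIndices(const std::vector<int64_t>& view_indices,
                                              const std::vector<int64_t>& view_begins) {
      std::vector<int64_t> out(view_indices.size());
      for (size_t k = 0; k < view_indices.size(); ++k) {
        out[k] = view_begins[k] + view_indices[k];
      }
      return out;
    }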

* [TIR] Removed checks on buffer->shape.size()

Even after BufferShapeLegalize, rank-zero tensors may have an empty
shape.

* [TIR] Relaxed check on a bufferview's striding.

The original refactoring required that a bufferview have no explicit
striding, and instead take the striding from the buffer that it is
viewing.  Modified to allow a bufferview to specify striding, so long
as it is consistent with the viewed buffer's striding.  This
reproduces the behavior of StorageFlatten before the refactoring.

* [TIR] Fixed StorageFlatten test for shape_legalize.

AttrStmtNodes that contain rewritten Buffers need to be rewritten as
well.

* [TIR] Assigned storage scope

The earlier stage of the refactor left a buffer's storage scope
undefined if its scope was not determined by the IterVar of a loop
containing its allocation.  Now, these are explicitly set to
StorageScope::kGlobal, to match the previous behavior of
StorageFlatten.

* Updated ICHECK_EQ to CHECK_EQ for a test that depends on user-provided
data.

* Added comments in storage_flatten.cc, indicating why buffer_bind_scope
needs special handling.

* Updated comment with a few examples of where compact buffers are
assumed to have no strides defined.

* Updated following @csullivan's comments.

* Added fuzzy mapping to the BufferShapeLegalize.

This maintains the earlier behavior of StorageFlatten, which allows
buffer views to be mapped to higher-dimension buffers if the view
extent is 1 in each extra dimension.
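
A sketch of the fuzzy-match condition (assumed shape; the extra
dimensions are taken here to be the leading ones):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // A rank-M view may bind to a rank-N buffer region (N >= M) only if the
    // region has extent 1 in each of the N - M extra leading dimensions.
    bool FuzzyBindAllowed(std::size_t view_rank, const std::vector<int64_t>& region_extents) {
      if (region_extents.size() < view_rank) return false;
      std::size_t extra = region_extents.size() - view_rank;
      for (std::size_t k = 0; k < extra; ++k) {
        if (region_extents[k] != 1) return false;
      }
      return true;
    }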

* Updated BufferShapeLegalize; asserts need to be inside the buffer_bind_scope.

* Pulled all shape-dependent behavior into BufferShapeLegalize.

Previously, BufferBindUnwrapper passed fuzzy_match=true to
ArgBinder::BindBuffer, which could change the number of dimensions.
Now, all buffer dimensions should be updated prior to
BufferBindUnwrapper, and it is an error to have mismatched dimensions
in BufferBindUnwrapper.

* Added another pass to remove verifiable assert statements.

ArgBinder::BindBuffer inserts these assert statements if they are not
verifiable at the time of substitution.  Previously, with one giant
substitution, the assertions were verifiable at that time.  After the
refactor, with substitutions done in multiple stages for
shape/stride/buffer_bind_scope, we need to clean up any assertions
that are verifiable after all substitutions have occurred.
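
A minimal sketch of such a pass (the class name is illustrative and
the committed implementation may differ): an arithmetic analyzer
proves the condition, and the assert is replaced by its body.

    #include <tvm/arith/analyzer.h>
    #include <tvm/tir/stmt_functor.h>

    class RemoveProvableAsserts : public tvm::tir::StmtExprMutator {
     public:
      tvm::tir::Stmt VisitStmt_(const tvm::tir::AssertStmtNode* op) final {
        if (analyzer_.CanProve(op->condition)) {
          // The assertion always holds, so only its body is kept.
          return this->VisitStmt(op->body);
        }
        return tvm::tir::StmtExprMutator::VisitStmt_(op);
      }

     private:
      tvm::arith::Analyzer analyzer_;
    };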

* Minor cleanup

- Removed StorageFlattener::BufferEntry::RelIndex, behavior already
  handled by BufferShapeLegalize.

- Improved comments and error messages.

- Extracted duplicate behavior in BufferLoad/BufferStore handling in
  BufferShapeLegalize.

* Updated to handle BufferRealizeNode with no defined bounds.

* Updated to be less aggressive when checking AssertStmt

A true Assert statement can be removed, but a false Assert statement
would require control-flow analysis (CFA) to report it as a
compile-time error.  Since we only need the removal of true assert
statements, the CFA is skipped for now.
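
For context, the refactored pass is still exposed as
tvm::tir::transform::StorageFlatten; a hedged usage sketch (the cache
line size of 64 is only an example value):

    #include <tvm/ir/module.h>
    #include <tvm/ir/transform.h>
    #include <tvm/tir/transform.h>

    tvm::IRModule FlattenStorage(tvm::IRModule mod) {
      // StorageFlatten(cache_line_size, create_bound_attributes)
      tvm::transform::Pass pass = tvm::tir::transform::StorageFlatten(64, false);
      return pass(mod);
    }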
Lunderberg authored Oct 1, 2021
1 parent 4b4b3d0 commit 659f3b7
Showing 4 changed files with 1,136 additions and 200 deletions.

8 changes: 8 additions & 0 deletions include/tvm/tir/buffer.h
@@ -121,6 +121,14 @@ class BufferNode : public Object {
     return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
   }
 
+  /*! \brief Determine the offset in the buffer of the given index.
+   *
+   * Returns the buffer offset, in number of elements of type dtype,
+   * without adjusting for number of lanes. (e.g. The number of
+   * float16x4 elements in a buffer of type float16x4.)
+   */
+  PrimExpr ElemOffset(Array<PrimExpr> index) const;
+
   static constexpr const char* _type_key = "tir.Buffer";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;

24 changes: 12 additions & 12 deletions src/tir/ir/buffer.cc
@@ -246,41 +246,41 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) {
 // The buffer offset in convention of number of elements of
 // original data ignoring number of lanes.
 // We also perform optimization to simplify the indexing expression.
-inline PrimExpr ElemOffset(const BufferNode* n, Array<PrimExpr> index) {
-  PrimExpr base = n->elem_offset;
+PrimExpr BufferNode::ElemOffset(Array<PrimExpr> index) const {
+  PrimExpr base = this->elem_offset;
   arith::Analyzer ana;
-  if (n->strides.size() == 0) {
+  if (this->strides.size() == 0) {
     // Scalar case
-    if (n->shape.size() == 0 && index.size() == 1) {
+    if (this->shape.size() == 0 && index.size() == 1) {
       auto is_int = index[0].as<IntImmNode>();
       ICHECK(is_int && is_int->value == 0);
       base = base + index[0];
     } else {
-      ICHECK_EQ(n->shape.size(), index.size());
+      ICHECK_EQ(this->shape.size(), index.size());
       if (index.size() > 0) {
        PrimExpr offset = index[0];
        for (size_t i = 1; i < index.size(); ++i) {
-          offset = MergeMulMod(&ana, offset * n->shape[i] + index[i]);
+          offset = MergeMulMod(&ana, offset * this->shape[i] + index[i]);
        }
        base = base + offset;
      }
    }
  } else {
-    ICHECK_EQ(n->strides.size(), index.size());
+    ICHECK_EQ(this->strides.size(), index.size());
    if (is_zero(base)) {
-      base = MergeMulMod(&ana, index[0] * n->strides[0]);
+      base = MergeMulMod(&ana, index[0] * this->strides[0]);
    } else {
-      base = MergeMulMod(&ana, base + index[0] * n->strides[0]);
+      base = MergeMulMod(&ana, base + index[0] * this->strides[0]);
    }
    for (size_t i = 1; i < index.size(); ++i) {
-      base = MergeMulMod(&ana, base + index[i] * n->strides[i]);
+      base = MergeMulMod(&ana, base + index[i] * this->strides[i]);
    }
  }
  return base;
 }
 
 inline PrimExpr BufferOffset(const BufferNode* n, Array<PrimExpr> index, DataType dtype) {
-  PrimExpr offset = ElemOffset(n, index);
+  PrimExpr offset = n->ElemOffset(index);
   if (n->dtype.lanes() != 1) {
     offset = offset * make_const(offset.dtype(), dtype.lanes());
   }
@@ -353,7 +353,7 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
   ICHECK(n != nullptr);
   arith::Analyzer ana;
   begins = SimplifyArray(&ana, begins);
-  PrimExpr elem_offset = ana.Simplify(ElemOffset(n, begins));
+  PrimExpr elem_offset = ana.Simplify(n->ElemOffset(begins));
   Array<PrimExpr> strides = n->strides;
   if (strides.size() == 0) {
     bool can_relax = true;

25 changes: 15 additions & 10 deletions src/tir/transforms/arg_binder.cc
@@ -204,7 +204,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
   def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type));
   init_nest_.emplace_back(
       LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
-  PrimExpr is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides});
+  PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides});
   if (buffer->strides.size() == 0) {
     // Assert the buffer is compact
     DataType stype = buffer->DefaultIndexType();
@@ -226,7 +226,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
           foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
                 const_true(1), conds),
           stride_msg, Evaluate(0));
-      check = IfThenElse(Not(is_null), check, Stmt());
+      check = IfThenElse(Not(v_strides_is_null), check, Stmt());
       asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
     }
   } else if (buffer->buffer_type == kAutoBroadcast) {
@@ -239,24 +239,29 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
       PrimExpr value =
           cast(buffer->shape[k].dtype(),
               Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1)));
-      value = tvm::if_then_else(is_null, stride, value);
+      value = tvm::if_then_else(v_strides_is_null, stride, value);
       value = tvm::if_then_else(buffer->shape[k] == 1, 0, value);
       Bind_(buffer->strides[k], value, field_name.str(), true);
       stride = analyzer_.Simplify(stride * buffer->shape[k]);
     }
   } else {
-    std::ostringstream stride_null_err_msg;
-    stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
-    asserts_.emplace_back(
-        AssertStmt(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop));
+    PrimExpr stride_from_shape = 1;
 
-    for (size_t k = 0; k < buffer->strides.size(); ++k) {
+    for (int k = buffer->strides.size() - 1; k >= 0; k--) {
       std::ostringstream field_name;
       field_name << v_strides->name_hint << '[' << k << ']';
+
+      PrimExpr explicit_stride =
+          cast(buffer->shape[k].dtype(),
+               Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1)));
+
       Bind_(buffer->strides[k],
-            cast(buffer->shape[k].dtype(),
-                 Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))),
+            tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride),
             field_name.str(), true);
+
+      stride_from_shape *=
+          cast(buffer->shape[k].dtype(),
+               Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1)));
     }
   }
   // Byte_offset field.