Skip to content

Commit

Permalink
[TIR] Add tir::builtin::undef (#12266)
Browse files Browse the repository at this point in the history
* [UnitTest] RemoveStoreUndef, simplest behavior

* [RemoveStoreUndef] First implementation

* [UnitTest] RemoveStoreUndef, stores that depend through LetStmt

* [UnitTest] RemoveStoreUndef, LetStmt handling, error on illegal usage

* [RemoveStoreUndef] Added error checking for illegal T.undef() usage

* Fix lint error

* Use const ref for list of stores to remove

* Verify that removed expression has no other side effects

* Fix lint error
  • Loading branch information
Lunderberg authored Aug 6, 2022
1 parent 2a7af61 commit c4aab62
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,14 @@ TVM_DLL const Op& mem_copy();
*/
TVM_DLL const Op& assume();

/*!
* \brief Returns an initialized but arbitrary value
*
* Compile-time representation of memory locations whose values may be
* altered as a result of optimizations.
*/
TVM_DLL const Op& undef();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,17 @@ def RemoveAssume():
return _ffi_api.RemoveAssume() # type: ignore


def RemoveStoreUndef():
"""Remove stores of undefined values from the Stmt.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.RemoveStoreUndef() # type: ignore


def BF16Legalize():
"""Legalize bf16 typed Ops.
Runs BF16Promote, BF16CastElimination and BF16TypeLowering
Expand Down
4 changes: 4 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ TIR_DEFINE_BUILTIN_FUNC(assume)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
.set_num_inputs(1);

TIR_DEFINE_BUILTIN_FUNC(undef)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kReadState))
.set_num_inputs(0);

} // namespace builtin
} // namespace tir
} // namespace tvm
179 changes: 179 additions & 0 deletions src/tir/transforms/remove_store_undef.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file remove_store_undef.cc
* \brief Remove stores of tir::builtin::undef
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {

class StoreUndefLocator : public StmtExprVisitor {
public:
static std::unordered_set<const BufferStoreNode*> Locate(Stmt stmt) {
StoreUndefLocator locator;
locator(std::move(stmt));
return locator.undef_stores_;
}

private:
StoreUndefLocator() = default;

void VisitStmt_(const BufferStoreNode* op) final {
bool stash_undef = false;
std::swap(has_undef_, stash_undef);
StmtExprVisitor::VisitExpr(op->value);
std::swap(has_undef_, stash_undef);
if (stash_undef) {
ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState)
<< "Error: T.undef() used in BufferStore expressions "
<< "must not have other side effects";
undef_stores_.insert(op);
}
}

void VisitExpr_(const BufferLoadNode* op) final {
// This function left deliberately empty. builtin::undef()
// shouldn't occur in the indices of BufferLoad. Avoiding
// visiting the indices catches the builtin::undef in
// ValidateAllUndefRemoved.
}

void VisitStmt_(const LetStmtNode* op) final {
bool stash_undef = false;
std::swap(has_undef_, stash_undef);
StmtExprVisitor::VisitExpr(op->value);
std::swap(has_undef_, stash_undef);
if (stash_undef) {
ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState)
<< "Error: T.undef() used in Let expressions "
<< "must not have other side effects";
var_bindings_with_undef_.insert(op->var.get());
}

StmtExprVisitor::VisitStmt(op->body);
}

void VisitExpr_(const VarNode* op) final {
if (var_bindings_with_undef_.count(op)) {
has_undef_ = true;
}
}

void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::undef())) {
has_undef_ = true;
}
StmtExprVisitor::VisitExpr_(op);
}

bool has_undef_{false};

std::unordered_set<const VarNode*> var_bindings_with_undef_;
std::unordered_set<const BufferStoreNode*> undef_stores_;
};

// Remove any BufferStores whose value depends on T.undef
class StoreUndefRemover : public StmtExprMutator {
public:
static Stmt Apply(Stmt stmt) {
auto to_remove = StoreUndefLocator::Locate(stmt);
StoreUndefRemover mutator(to_remove);
return mutator(std::move(stmt));
}

private:
using Parent = StmtExprMutator;

explicit StoreUndefRemover(const std::unordered_set<const BufferStoreNode*>& to_remove)
: to_remove_(to_remove) {}

Stmt VisitStmt_(const BufferStoreNode* op) final {
if (to_remove_.count(op)) {
return Evaluate(0);
} else {
return Parent::VisitStmt_(op);
}
}

const std::unordered_set<const BufferStoreNode*>& to_remove_;
};

// Remove any BufferStores whose value depends on T.undef
class ContainsUndefChecker : public StmtExprVisitor {
public:
static bool Check(const Stmt& stmt) {
ContainsUndefChecker checker;
checker(stmt);
return checker.contains_undef;
}

private:
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::undef())) {
contains_undef = true;
}
StmtExprVisitor::VisitExpr_(op);
}

bool contains_undef{false};
};

namespace transform {
Pass RemoveStoreUndefInternal() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = StoreUndefRemover::Apply(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveStoreUndefInternal", {});
}

Pass ValidateAllUndefRemoved() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool contains_undef = ContainsUndefChecker::Check(f->body);
ICHECK(!contains_undef) << "Expected removal of BufferStore containing builtin::undef() "
<< "to remove all instances of builtin::undef(). "
<< "Instead, result was"
<< "\n"
<< f;
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.ValidateAllUndefRemoved", {});
}

Pass RemoveStoreUndef() {
return Sequential({RemoveStoreUndefInternal(), RemoveNoOp(), ValidateAllUndefRemoved()},
"tir.RemoveStoreUndef");
}

TVM_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef);

} // namespace transform

} // namespace tir
} // namespace tvm
94 changes: 94 additions & 0 deletions tests/python/unittest/test_tir_transform_remove_undef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import tvm.testing
from tvm.script import tir as T
from tvm import TVMError


class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
@tvm.testing.fixture
def transform(self):
return tvm.tir.transform.RemoveStoreUndef()


class TestRemoveStoreUndef(BaseBeforeAfter):
"""Remove a store whose value is T.undef()"""

def before(A: T.Buffer[1, "int32"]):
A[0] = T.undef(dtype="int32")

def expected(A: T.Buffer[1, "int32"]):
T.evaluate(0)


class TestRemoveStoreUndefExpression(BaseBeforeAfter):
"""Expressions containing T.undef() are removed"""

def before(A: T.Buffer[1, "int32"]):
A[0] = 1 + T.undef(dtype="int32")

def expected(A: T.Buffer[1, "int32"]):
T.evaluate(0)


class TestKeepOtherCallNodes(BaseBeforeAfter):
"""Expressions containing other CallNodes are not removed"""

def before(A: T.Buffer[1, "int32"], n: T.int32):
A[0] = T.shift_left(n, 1, dtype="int32")

expected = before


class TestRemoveLetUndef(BaseBeforeAfter):
"""Remove a store whose value is bound to T.undef()"""

def before(A: T.Buffer[1, "int32"]):
val = T.undef(dtype="int32")
A[0] = val

def expected(A: T.Buffer[1, "int32"]):
T.evaluate(0)


class TestRaiseErrorForUndefAsStoreIndices(BaseBeforeAfter):
"""Use of T.undef() as buffer indices is an error"""

def before(A: T.Buffer[1, "int32"]):
val = T.undef(dtype="int32")
A[val] = 5

expected = TVMError


class TestRaiseErrorForUndefAsLoadIndices(BaseBeforeAfter):
"""Use of T.undef() as buffer indices is an error
Even though this occurs as part of the BufferStore's value, the
T.undef() may not appear in a buffer's indices.
"""

def before(A: T.Buffer[1, "int32"], B: T.Buffer[1, "int32"]):
B[0] = A[T.undef(dtype="int32")]

expected = TVMError


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit c4aab62

Please sign in to comment.