[TE] Support schedulable TIR compute definitions in TOPI (#11589)
This PR adds `te.extern_primfunc`, an interface around TE ExternOp that allows a TVMScript-defined, schedulable TIR PrimFunc to be inlined into a TE compute graph. As a result, TIR can be used for compute definitions in Relay OpStrategies and, paired with the meta-schedule support in Relay introduced in #10578, these compute definitions can be scheduled and tuned, as demonstrated in the attached tests.

Prior to this, compute definitions were limited to those expressible in TE. As a consequence of this patch and ongoing improvements to TVMScript metaprogramming (#11097), TOPI can be extended to include compute and scheduling functions that uniformly target schedulable TIR.
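For illustration, a minimal end-to-end sketch of the new entry point (the shapes, names, and doubling computation are illustrative assumptions, not part of the patch):

import tvm
from tvm import te
from tvm.script import tir as T

@T.prim_func
def scale(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

A = te.placeholder((128, 128), name="A")
C = te.extern_primfunc([A], scale)  # C is backed by the store-only buffer B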
csullivan authored Jun 13, 2022
1 parent e61ad7a commit 1420df7
Showing 7 changed files with 668 additions and 53 deletions.
1 change: 1 addition & 0 deletions python/tvm/te/__init__.py
@@ -40,6 +40,7 @@
from .operation import placeholder, compute, scan, extern, var, size_var, const
from .operation import thread_axis, reduce_axis
from .operation import create_prim_func
+from .operation import extern_primfunc

from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
from .autodiff import gradient
82 changes: 82 additions & 0 deletions python/tvm/te/operation.py
@@ -24,6 +24,7 @@
import tvm._ffi
import tvm.tir
import tvm.tir._ffi_api
+import tvm.arith._ffi_api
from tvm._ffi.base import string_types
from tvm.ir import Array
from tvm.runtime import convert
@@ -354,6 +355,87 @@ def extern(
    return res[0] if len(res) == 1 else res


+def extern_primfunc(input_tensors: List[_tensor.Tensor], primfunc: tvm.tir.PrimFunc, **kwargs):
+    """Compute tensors via a schedulable TIR PrimFunc
+
+    Parameters
+    ----------
+    input_tensors: list of Tensor
+        Input tensors that map to the corresponding primfunc input params.
+
+    primfunc: PrimFunc
+        The TIR PrimFunc
+
+    Returns
+    -------
+    tensor: Tensor or list of Tensors
+        The created tensor or tuple of tensors if it contains multiple outputs.
+
+    Example
+    -------
+    In the code below, a TVMScript-defined TIR PrimFunc is inlined into a TE
+    ExternOp so that it can serve as a compute definition in a TE compute
+    graph. The store-only buffer ``B`` becomes the output, so only the tensor
+    backing the load-only buffer ``A`` is passed in ``input_tensors``.
+
+    .. code-block:: python
+
+        A = te.placeholder((128, 128), name="A")
+
+        @T.prim_func
+        def before_split(a: T.handle, b: T.handle) -> None:
+            A = T.match_buffer(a, (128, 128))
+            B = T.match_buffer(b, (128, 128))
+            for i, j in T.grid(128, 128):
+                with T.block("B"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+        C = te.extern_primfunc([A], before_split)
+    """
+    access_map = {
+        k: tuple(v) for k, v in tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc).items()
+    }
+    in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
+    out_buffers = [buf for buf, access in access_map.items() if len(access[1])]
+    assert in_buffers, "PrimFunc has no input buffers"
+    assert out_buffers, "PrimFunc has no output buffers"
+
+    outputs = []
+    inplace = []
+    input_buffers = in_buffers
+    for obuf in out_buffers:
+        if obuf in in_buffers:
+            inplace.append(obuf)
+        else:
+            outputs.append(obuf)
+
+    if not outputs:
+        iobuf = inplace.pop()
+        input_buffers.remove(iobuf)
+        outputs = [iobuf]
+
+    assert len(input_buffers) == len(input_tensors), (
+        "The number of provided input tensors does not match the number of "
+        "input buffers in the primfunc"
+    )
+    for tensor, buffer in zip(input_tensors, input_buffers):
+        # TODO(csullivan): Can a stronger comparison between Tensor<>Buffer be made?
+        assert tensor.shape == buffer.shape, (
+            "The input tensors provided do not match the input buffers in the "
+            "primfunc. Please check that the order of the input te.Tensors and "
+            "the order of the primfunc variables in the params list agree."
+        )
+    output = extern(
+        [buf.shape for buf in outputs],
+        input_tensors,
+        lambda ins, outs: primfunc.body,
+        in_buffers=input_buffers,
+        out_buffers=outputs,
+        **kwargs,
+    )
+    return output


def var(name="tindex", dtype="int32", span=None):
"""Create a new variable with specified name and dtype
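Usage note (a sketch reusing the assumed names from the docstring example above): the tensor returned by extern_primfunc composes with ordinary TE stages, which is what lets schedulable TIR sit inside a larger TE compute graph.

# A and before_split as in the docstring example
C = te.extern_primfunc([A], before_split)
D = te.compute((128, 128), lambda i, j: C[i, j] + 1.0, name="D")  # downstream TE stage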
106 changes: 86 additions & 20 deletions src/arith/domain_touched.cc
@@ -26,6 +26,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

+#include <tuple>
#include <unordered_map>
#include <unordered_set>

@@ -34,18 +35,54 @@ namespace arith {

using namespace tir;

+namespace {
+
+using BufferTouches = std::vector<std::vector<IntSet>>;
+
+struct LoadAccess {
+  BufferTouches set;
+};
+
+struct StoreAccess {
+  BufferTouches set;
+};
+
+struct CombinedAccess {
+  BufferTouches set;
+};
+
+using BufferDomainAccess = std::tuple<LoadAccess, StoreAccess, CombinedAccess>;
+
+}  // namespace

// Find Read region of the tensor in the stmt.
class BufferTouchedDomain final : public StmtExprVisitor {
 public:
-  BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores)
-      : buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {}
+  BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); }

-  Region Find(const Stmt& stmt) {
-    operator()(stmt);
+  std::unordered_map<const BufferNode*, BufferDomainAccess>& GetAccessedBufferRegions() {
+    return buffer_access_map_;
+  }
+
+  Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) {
+    auto kv = buffer_access_map_.find(buffer.get());
+    CHECK(kv != buffer_access_map_.end())
+        << "The requested buffer is not contained in the provided stmt body.";
+
    Region ret;
    Range none;
-    for (size_t i = 0; i < bounds_.size(); ++i) {
-      ret.push_back(arith::Union(bounds_[i]).CoverRange(none));
+    BufferTouches bounds;
+    if (consider_loads && consider_stores) {
+      bounds = std::get<CombinedAccess>(kv->second).set;
+    } else if (consider_loads) {
+      bounds = std::get<LoadAccess>(kv->second).set;
+    } else if (consider_stores) {
+      bounds = std::get<StoreAccess>(kv->second).set;
+    } else {
+      CHECK(false) << "Must consider at least one of loads or stores, but both are false";
+    }
+    for (size_t i = 0; i < bounds.size(); ++i) {
+      ret.push_back(arith::Union(bounds[i]).CoverRange(none));
    }
    return ret;
  }
@@ -78,41 +115,70 @@
}

  void VisitExpr_(const BufferLoadNode* op) final {
-    if (consider_loads_ && buffer_.same_as(op->buffer)) {
-      Touch(op->indices);
-    }
+    // Record load-exclusive buffer access
+    Touch(&std::get<LoadAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
+    // Record load-store inclusive buffer access
+    Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode* op) final {
-    if (consider_stores_ && buffer_.same_as(op->buffer)) {
-      Touch(op->indices);
-    }
+    // Record store-exclusive buffer access
+    Touch(&std::get<StoreAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
+    // Record load-store inclusive buffer access
+    Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
    StmtExprVisitor::VisitStmt_(op);
  }

 private:
-  void Touch(const Array<PrimExpr>& args) {
-    if (args.size() > bounds_.size()) {
-      bounds_.resize(args.size());
+  template <typename ArrayType>
+  void Touch(BufferTouches* bounds, const ArrayType& args) const {
+    if (args.size() > bounds->size()) {
+      bounds->resize(args.size());
    }
    for (size_t i = 0; i < args.size(); ++i) {
-      bounds_[i].emplace_back(EvalSet(args[i], dom_map_));
+      (*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
    }
  }

-  const Buffer& buffer_;
-  bool consider_loads_, consider_stores_;
-  std::vector<std::vector<IntSet> > bounds_;
+  std::unordered_map<const BufferNode*, BufferDomainAccess> buffer_access_map_;
  std::unordered_map<const VarNode*, IntSet> dom_map_;
};

Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
                     bool consider_stores) {
-  return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt);
+  return BufferTouchedDomain(stmt).FindUnion(buffer, consider_loads, consider_stores);
}

+Map<Buffer, runtime::ADT> DomainTouchedAccessMap(const PrimFunc& func) {
+  auto buffer_access_map = BufferTouchedDomain(func->body).GetAccessedBufferRegions();
+  Map<Buffer, runtime::ADT> ret;
+  auto& buffer_map = func->buffer_map;
+  for (auto& var : func->params) {
+    auto& buffer = buffer_map[var];
+    auto& access = buffer_access_map[buffer.get()];
+    Array<Array<IntSet>> loads, stores, combined;
+    for (std::vector<IntSet>& touch : std::get<LoadAccess>(access).set) {
+      loads.push_back(Array<IntSet>(touch));
+    }
+    for (std::vector<IntSet>& touch : std::get<StoreAccess>(access).set) {
+      stores.push_back(Array<IntSet>(touch));
+    }
+    for (std::vector<IntSet>& touch : std::get<CombinedAccess>(access).set) {
+      combined.push_back(Array<IntSet>(touch));
+    }
+
+    std::vector<ObjectRef> fields;
+    fields.push_back(loads);
+    fields.push_back(stores);
+    fields.push_back(combined);
+    ret.Set(buffer, runtime::ADT::Tuple(fields));
+  }
+  return ret;
+}

TVM_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched);
TVM_REGISTER_GLOBAL("arith.DomainTouchedAccessMap").set_body_typed(DomainTouchedAccessMap);

} // namespace arith
} // namespace tvm
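To see what the new analysis returns, a hedged Python sketch mirroring the call made by te.extern_primfunc; here `func` stands for any `@T.prim_func` (e.g. `before_split` above). Each buffer maps to a (loads, stores, combined) triple of per-axis IntSet lists.

import tvm
from tvm.arith import _ffi_api as arith_ffi

# func: an arbitrary @T.prim_func defined elsewhere
access_map = {k: tuple(v) for k, v in arith_ffi.DomainTouchedAccessMap(func).items()}
for buf, (loads, stores, _combined) in access_map.items():
    # Buffers with load sets are read; buffers with store sets are written.
    print(buf.name, "read:", len(loads) > 0, "written:", len(stores) > 0)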
2 changes: 1 addition & 1 deletion src/relay/backend/task_extraction.cc
@@ -52,7 +52,7 @@ bool DefaultTaskFilter(const Array<te::Tensor>& args) {
    stack.pop_back();
    if (tensor->op->IsInstance<PlaceholderOpNode>()) {
      // do nothing
-    } else if (tensor->op->IsInstance<ComputeOpNode>()) {
+    } else if (tensor->op->IsInstance<ComputeOpNode>() || tensor->op->IsInstance<ExternOpNode>()) {
      Array<Tensor> inputs = tensor->op->InputTensors();
      for (const Tensor& v : inputs) {
        if (!visited.count(v.get())) {
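With ExternOpNode now accepted by the default task filter, Relay task extraction no longer skips ops whose compute is defined through te.extern_primfunc. A rough downstream sketch; the extract_task_from_relay entry point and its signature are assumed from the meta-schedule work in #10578, not shown in this patch:

from tvm import meta_schedule as ms

# relay_mod and params are assumed to be defined elsewhere
extracted = ms.extract_task_from_relay(relay_mod, target="llvm", params=params)
for task in extracted:
    print(task.task_name)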
(Diffs for the remaining 3 changed files are not shown.)
