[TE] Support schedulable TIR compute definitions in TOPI #11589

Merged · 7 commits · Jun 13, 2022
1 change: 1 addition & 0 deletions python/tvm/te/__init__.py
@@ -40,6 +40,7 @@
from .operation import placeholder, compute, scan, extern, var, size_var, const
from .operation import thread_axis, reduce_axis
from .operation import create_prim_func
from .operation import extern_primfunc

from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp
from .autodiff import gradient
82 changes: 82 additions & 0 deletions python/tvm/te/operation.py
@@ -24,6 +24,7 @@
import tvm._ffi
import tvm.tir
import tvm.tir._ffi_api
import tvm.arith._ffi_api
from tvm._ffi.base import string_types
from tvm.ir import Array
from tvm.runtime import convert
@@ -354,6 +355,87 @@ def extern(
return res[0] if len(res) == 1 else res


def extern_primfunc(input_tensors: List[_tensor.Tensor], primfunc: tvm.tir.PrimFunc, **kwargs):
"""Compute tensors via a schedulable TIR PrimFunc

Parameters
----------
input_tensors: list of Tensor
Input tensors that map to the corresponding primfunc input params.

primfunc: PrimFunc
The TIR PrimFunc

Returns
-------
    tensor: Tensor or list of Tensors
        The created tensor, or a tuple of tensors if the PrimFunc has multiple outputs.

    Example
    -------
    In the code below, a TVMScript-defined TIR PrimFunc is wrapped in a TE
    ExternOp. Applying te.create_prim_func to the resulting tensors then
    produces a single schedulable TIR PrimFunc for the whole graph.

.. code-block:: python

A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")

@T.prim_func
def before_split(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0

        C = te.extern_primfunc([A], before_split)
"""
    access_map = {
        k: tuple(v) for k, v in tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc).items()
    }
    # Each entry unpacks to (loads, stores, combined) access-region lists.
    in_buffers = [buf for buf, access in access_map.items() if len(access[0])]
    out_buffers = [buf for buf, access in access_map.items() if len(access[1])]
assert in_buffers, "PrimFunc has no input buffers"
assert out_buffers, "PrimFunc has no output buffers"

    outputs = []
    inplace = []
    # Copy so that removing an in-place buffer below does not mutate in_buffers.
    input_buffers = list(in_buffers)
    for obuf in out_buffers:
        if obuf in in_buffers:
            inplace.append(obuf)
        else:
            outputs.append(obuf)

    # If every output is written in place, expose one in-place buffer as the output.
    if not outputs:
        iobuf = inplace.pop()
        input_buffers.remove(iobuf)
        outputs = [iobuf]

    assert len(input_buffers) == len(input_tensors), (
        "The number of provided input tensors does not match the number of "
        "input buffers in the primfunc"
    )
for tensor, buffer in zip(input_tensors, input_buffers):
# TODO(csullivan): Can a stronger comparison between Tensor<>Buffer be made?
        assert tensor.shape == buffer.shape, (
            "The input tensors provided do not match the input buffers in the "
            "primfunc. Please check that the order of the input tensors and the "
            "order of the primfunc variables in the params list agree."
        )
output = extern(
[buf.shape for buf in outputs],
input_tensors,
lambda ins, outs: primfunc.body,
in_buffers=input_buffers,
out_buffers=outputs,
**kwargs,
)
return output
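A runnable sketch of the intended flow (illustrative names; it assumes te.create_prim_func accepts graphs containing the resulting ExternOp, as the docstring above describes):

    import tvm
    from tvm import te
    from tvm.script import tir as T

    @T.prim_func
    def scale(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, (128, 128))
        B = T.match_buffer(b, (128, 128))
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0

    A = te.placeholder((128, 128), name="A")
    C = te.extern_primfunc([A], scale)  # wraps the PrimFunc in a TE ExternOp
    # Compose with ordinary TE compute, then lower the graph back to one PrimFunc.
    D = te.compute((128, 128), lambda i, j: C[i, j] + 1.0, name="D")
    func = te.create_prim_func([A, D])
    sch = tvm.tir.Schedule(func)  # the result is schedulable TIR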


def var(name="tindex", dtype="int32", span=None):
"""Create a new variable with specified name and dtype

100 changes: 80 additions & 20 deletions src/arith/domain_touched.cc
@@ -26,6 +26,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <tuple>
#include <unordered_map>
#include <unordered_set>

@@ -34,18 +35,48 @@ namespace arith {

using namespace tir;

struct LoadAccess {
std::vector<std::vector<IntSet>> set;
};

struct StoreAccess {
std::vector<std::vector<IntSet>> set;
};

struct CombinedAccess {
std::vector<std::vector<IntSet>> set;
};

using BufferDomainAccess = std::tuple<LoadAccess, StoreAccess, CombinedAccess>;

// Find the touched (read/write) domain of each buffer in the stmt.
class BufferTouchedDomain final : public StmtExprVisitor {
public:
BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); }

std::unordered_map<const BufferNode*, BufferDomainAccess>& GetAccessedBufferRegions() {
return buffer_access_map_;
}

Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) {
auto kv = buffer_access_map_.find(buffer.get());
CHECK(kv != buffer_access_map_.end())
<< "The requested buffer is not contained in the provided stmt body.";

Region ret;
Range none;
    std::vector<std::vector<IntSet>> bounds;
if (consider_loads && consider_stores) {
bounds = std::get<CombinedAccess>(kv->second).set;
} else if (consider_loads) {
bounds = std::get<LoadAccess>(kv->second).set;
} else if (consider_stores) {
bounds = std::get<StoreAccess>(kv->second).set;
} else {
CHECK(false) << "Must consider at least on of either loads and stores, but both are false";
}
for (size_t i = 0; i < bounds.size(); ++i) {
ret.push_back(arith::Union(bounds[i]).CoverRange(none));
}
return ret;
}
@@ -78,41 +109,70 @@ class BufferTouchedDomain final : public StmtExprVisitor {
}

void VisitExpr_(const BufferLoadNode* op) final {
// Record load-exclusive buffer access
Touch(&std::get<LoadAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
// Record load-store inclusive buffer access
Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode* op) final {
// Record store-exclusive buffer access
Touch(&std::get<StoreAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
// Record load-store inclusive buffer access
Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
StmtExprVisitor::VisitStmt_(op);
}

private:
template <typename ArrayType>
void Touch(std::vector<std::vector<IntSet>>* bounds, const ArrayType& args) const {
if (args.size() > bounds->size()) {
bounds->resize(args.size());
}
for (size_t i = 0; i < args.size(); ++i) {
(*bounds)[i].emplace_back(EvalSet(args[i], dom_map_));
}
}

std::unordered_map<const BufferNode*, BufferDomainAccess> buffer_access_map_;
std::unordered_map<const VarNode*, IntSet> dom_map_;
};

Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
bool consider_stores) {
return BufferTouchedDomain(stmt).FindUnion(buffer, consider_loads, consider_stores);
}
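A minimal sketch of querying the per-buffer union through the FFI global registered below (assumed reachable as tvm.arith._ffi_api.DomainTouched; the two flags select load-only, store-only, or combined access sets):

    import tvm
    import tvm.arith._ffi_api
    from tvm.script import tir as T

    @T.prim_func
    def copy_func(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, (16,))
        B = T.match_buffer(b, (16,))
        for i in T.serial(16):
            with T.block("copy"):
                vi = T.axis.spatial(16, i)
                B[vi] = A[vi]

    A_buf = copy_func.buffer_map[copy_func.params[0]]
    # consider_loads=True, consider_stores=False -> the region read from A
    read_region = tvm.arith._ffi_api.DomainTouched(copy_func.body, A_buf, True, False)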

Map<Buffer, runtime::ADT> DomainTouchedAccessMap(const PrimFunc& func) {
auto buffer_access_map = BufferTouchedDomain(func->body).GetAccessedBufferRegions();
Map<Buffer, runtime::ADT> ret;
auto& buffer_map = func->buffer_map;
for (auto& var : func->params) {
auto& buffer = buffer_map[var];
auto& access = buffer_access_map[buffer.get()];
Array<Array<IntSet>> loads, stores, combined;
for (std::vector<IntSet>& touch : std::get<LoadAccess>(access).set) {
loads.push_back(Array<IntSet>(touch));
}
for (std::vector<IntSet>& touch : std::get<StoreAccess>(access).set) {
stores.push_back(Array<IntSet>(touch));
}
for (std::vector<IntSet>& touch : std::get<CombinedAccess>(access).set) {
combined.push_back(Array<IntSet>(touch));
}

std::vector<ObjectRef> fields;
fields.push_back(loads);
fields.push_back(stores);
fields.push_back(combined);
ret.Set(buffer, runtime::ADT::Tuple(fields));
}
return ret;
}

TVM_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched);
TVM_REGISTER_GLOBAL("arith.DomainTouchedAccessMap").set_body_typed(DomainTouchedAccessMap);

} // namespace arith
} // namespace tvm
2 changes: 1 addition & 1 deletion src/relay/backend/task_extraction.cc
@@ -52,7 +52,7 @@ bool DefaultTaskFilter(const Array<te::Tensor>& args) {
stack.pop_back();
if (tensor->op->IsInstance<PlaceholderOpNode>()) {
// do nothing
} else if (tensor->op->IsInstance<ComputeOpNode>() || tensor->op->IsInstance<ExternOpNode>()) {
Array<Tensor> inputs = tensor->op->InputTensors();
for (const Tensor& v : inputs) {
if (!visited.count(v.get())) {
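For context, a hypothetical Python rendering of this filter's traversal (not part of the PR) showing why ExternOpNode must be accepted: tensors produced by te.extern_primfunc are ExternOp outputs and would otherwise fall through to the rejecting branch.

    from tvm import te

    def default_task_filter(args):
        visited, stack = set(), list(args)
        while stack:
            tensor = stack.pop()
            if isinstance(tensor.op, te.PlaceholderOp):
                pass  # leaf input: nothing to traverse
            elif isinstance(tensor.op, (te.ComputeOp, te.ExternOp)):
                for inp in tensor.op.input_tensors:
                    if id(inp) not in visited:
                        visited.add(id(inp))
                        stack.append(inp)
            else:
                return False  # unsupported op kind: reject the task
        return True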