diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 1777d8707c7c..a52422f6c1d2 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -40,6 +40,7 @@ from .operation import placeholder, compute, scan, extern, var, size_var, const from .operation import thread_axis, reduce_axis from .operation import create_prim_func +from .operation import extern_primfunc from .tensor import PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp, ExternOp, HybridOp from .autodiff import gradient diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index df5dd2c4ffd8..ada5c369ad3b 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -24,6 +24,7 @@ import tvm._ffi import tvm.tir import tvm.tir._ffi_api +import tvm.arith._ffi_api from tvm._ffi.base import string_types from tvm.ir import Array from tvm.runtime import convert @@ -354,6 +355,87 @@ def extern( return res[0] if len(res) == 1 else res +def extern_primfunc(input_tensors: List[_tensor.Tensor], primfunc: tvm.tir.PrimFunc, **kwargs): + """Compute tensors via a schedulable TIR PrimFunc + + Parameters + ---------- + input_tensors: list of Tensor + Input tensors that map to the corresponding primfunc input params. + + primfunc: PrimFunc + The TIR PrimFunc + + Returns + ------- + tensor: Tensor or list of Tensors + The created tensor or tuple of tensors if it contains multiple outputs. + + Example + ------- + In the code below, a TVMScript defined TIR PrimFunc is inlined into + a TE ExternOp. Applying te.create_prim_func on this + + .. code-block:: python + + A = te.placeholder((128, 128), name="A") + B = te.placeholder((128, 128), name="B") + + @T.prim_func + def before_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + C = te.extern_primfunc([A, B], func) + """ + access_map = { + k: tuple(v) for k, v in tvm.arith._ffi_api.DomainTouchedAccessMap(primfunc).items() + } + in_buffers = [buf for buf, access in access_map.items() if len(access[0])] + out_buffers = [buf for buf, access in access_map.items() if len(access[1])] + assert in_buffers, "PrimFunc has no input buffers" + assert out_buffers, "PrimFunc has no output buffers" + + outputs = [] + inplace = [] + input_buffers = in_buffers + for obuf in out_buffers: + if obuf in in_buffers: + inplace.append(obuf) + else: + outputs.append(obuf) + + if not outputs: + iobuf = inplace.pop() + input_buffers.remove(iobuf) + outputs = [iobuf] + + assert len(input_buffers) == len(input_tensors), ( + "The number of provided input input_tensors does not match the number of ", + "input buffers in the primfunc", + ) + for tensor, buffer in zip(input_tensors, input_buffers): + # TODO(csullivan): Can a stronger comparison between Tensor<>Buffer be made? + assert tensor.shape == buffer.shape, ( + "The input input_tensors provided do not match the input buffers in the ", + "primfunc. 
Please check that the order of input te.Input_Tensors and the ", + "order of the primfunc variables in the params list agree.", + ) + output = extern( + [buf.shape for buf in outputs], + input_tensors, + lambda ins, outs: primfunc.body, + in_buffers=input_buffers, + out_buffers=outputs, + **kwargs, + ) + return output + + def var(name="tindex", dtype="int32", span=None): """Create a new variable with specified name and dtype diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 3c3da5f4b99b..8874f4f16506 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -34,18 +35,54 @@ namespace arith { using namespace tir; +namespace { + +using BufferTouches = std::vector>; + +struct LoadAccess { + BufferTouches set; +}; + +struct StoreAccess { + BufferTouches set; +}; + +struct CombinedAccess { + BufferTouches set; +}; + +using BufferDomainAccess = std::tuple; + +} // namespace + // Find Read region of the tensor in the stmt. class BufferTouchedDomain final : public StmtExprVisitor { public: - BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores) - : buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {} + BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); } + + std::unordered_map& GetAccessedBufferRegions() { + return buffer_access_map_; + } + + Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) { + auto kv = buffer_access_map_.find(buffer.get()); + CHECK(kv != buffer_access_map_.end()) + << "The requested buffer is not contained in the provided stmt body."; - Region Find(const Stmt& stmt) { - operator()(stmt); Region ret; Range none; - for (size_t i = 0; i < bounds_.size(); ++i) { - ret.push_back(arith::Union(bounds_[i]).CoverRange(none)); + BufferTouches bounds; + if (consider_loads && consider_stores) { + bounds = std::get(kv->second).set; + } else if (consider_loads) { + bounds = std::get(kv->second).set; + } else if (consider_stores) { + bounds = std::get(kv->second).set; + } else { + CHECK(false) << "Must consider at least on of either loads and stores, but both are false"; + } + for (size_t i = 0; i < bounds.size(); ++i) { + ret.push_back(arith::Union(bounds[i]).CoverRange(none)); } return ret; } @@ -78,41 +115,70 @@ class BufferTouchedDomain final : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) final { - if (consider_loads_ && buffer_.same_as(op->buffer)) { - Touch(op->indices); - } + // Record load-exclusive buffer access + Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); + // Record load-store inclusive buffer access + Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode* op) final { - if (consider_stores_ && buffer_.same_as(op->buffer)) { - Touch(op->indices); - } + // Record store-exclusive buffer access + Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); + // Record load-store inclusive buffer access + Touch(&std::get(buffer_access_map_[op->buffer.get()]).set, op->indices); StmtExprVisitor::VisitStmt_(op); } private: - void Touch(const Array& args) { - if (args.size() > bounds_.size()) { - bounds_.resize(args.size()); + template + void Touch(BufferTouches* bounds, const ArrayType& args) const { + if (args.size() > bounds->size()) { + bounds->resize(args.size()); } for (size_t i = 0; i < args.size(); ++i) { - 
bounds_[i].emplace_back(EvalSet(args[i], dom_map_)); + (*bounds)[i].emplace_back(EvalSet(args[i], dom_map_)); } } - const Buffer& buffer_; - bool consider_loads_, consider_stores_; - std::vector > bounds_; + std::unordered_map buffer_access_map_; std::unordered_map dom_map_; }; Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, bool consider_stores) { - return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt); + return BufferTouchedDomain(stmt).FindUnion(buffer, consider_loads, consider_stores); +} + +Map DomainTouchedAccessMap(const PrimFunc& func) { + auto buffer_access_map = BufferTouchedDomain(func->body).GetAccessedBufferRegions(); + Map ret; + auto& buffer_map = func->buffer_map; + for (auto& var : func->params) { + auto& buffer = buffer_map[var]; + auto& access = buffer_access_map[buffer.get()]; + Array> loads, stores, combined; + for (std::vector& touch : std::get(access).set) { + loads.push_back(Array(touch)); + } + for (std::vector& touch : std::get(access).set) { + stores.push_back(Array(touch)); + } + for (std::vector& touch : std::get(access).set) { + combined.push_back(Array(touch)); + } + + std::vector fields; + fields.push_back(loads); + fields.push_back(stores); + fields.push_back(combined); + ret.Set(buffer, runtime::ADT::Tuple(fields)); + } + return ret; } TVM_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched); +TVM_REGISTER_GLOBAL("arith.DomainTouchedAccessMap").set_body_typed(DomainTouchedAccessMap); } // namespace arith } // namespace tvm diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 6ec881111d77..421a92c245e7 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -52,7 +52,7 @@ bool DefaultTaskFilter(const Array& args) { stack.pop_back(); if (tensor->op->IsInstance()) { // do nothing - } else if (tensor->op->IsInstance()) { + } else if (tensor->op->IsInstance() || tensor->op->IsInstance()) { Array inputs = tensor->op->InputTensors(); for (const Tensor& v : inputs) { if (!visited.count(v.get())) { diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 27cfdd605c5d..2aeb799a04cb 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -395,8 +395,7 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf /*annotations=*/extern_op->attrs)); } -PrimFunc CreatePrimFunc(const Array& arg_list) { - // Step 1. Create tensor read graph. +Array CollectOrderedOps(const Array& arg_list) { Array arg_ops; for (const te::Tensor& arg : arg_list) { arg_ops.push_back(arg->op); @@ -404,53 +403,67 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { te::ReadGraph g = te::CreateReadGraph(arg_ops); Array order = te::PostDFSOrder(arg_ops, g); - // Step 2. Checking all Operations are supported. for (const te::Operation& op : order) { if (!(op->IsInstance() || op->IsInstance() || op->IsInstance())) LOG(FATAL) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". " << "Only te.placeholder and te.compute are allowed for now."; } + return order; +} - // Infomations used in CreatePrimFunc and its sub-functions. - CreateFuncInfo info(arg_list); - // Root body stmts. 
- Array root_stmts; - // Analyzer - arith::Analyzer analyzer; +void InitializeBufferBinds(const Array& ordered_ops, CreateFuncInfo* info) { + // Process any TE operations which contain user defined buffers + for (const auto& op : ordered_ops) { + // Initialize the tensor2buffer binds map with buffers defined by the te.extern + if (const auto* extern_op = op.as()) { + ICHECK_EQ(extern_op->inputs.size(), extern_op->input_placeholders.size()); + for (size_t i = 0; i < extern_op->inputs.size(); ++i) { + const te::Tensor& input = extern_op->inputs[i]; + const Buffer& buffer = extern_op->input_placeholders[i]; + info->tensor2buffers[input] = buffer; + } + } + } +} - // Step 3. Rewrite compute stages into blocks. - for (const te::Operation& op : order) { - if (const auto* placeholder = op.as()) { - // Case 1. PlaceholderOp (te.placeholder) - ICHECK_EQ(op->num_outputs(), 1); - const te::Tensor& tensor = op.output(0); - // Check op is in op list - ICHECK(info.IsArg(tensor)); +void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Array* root_stmts, + arith::Analyzer* analyzer) { + if (const auto* placeholder = op.as()) { + // Case 1. PlaceholderOp (te.placeholder) + ICHECK_EQ(op->num_outputs(), 1); + const te::Tensor& tensor = op.output(0); + // Check op is in op list + ICHECK(info->IsArg(tensor)); + // Declare a buffer for any argument tensors without a pre-existing + // buffer declaration recorded in the tensor2buffer binds map + if (info->tensor2buffers.count(tensor) == 0) { const Buffer& buffer = decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name, "global"); - info.tensor2buffers[tensor] = buffer; - } else if (const auto* compute_op = op.as()) { - // Case 2. ComputeOp (te.compute) - root_stmts.push_back( - GenerateStmtFromCompute(GetRef(compute_op), &info, &analyzer)); - } else if (const auto extern_op = op.as()) { - // Case 3. ExternOp (te.extern) - root_stmts.push_back(GenerateStmtFromExternOp(GetRef(extern_op), &info)); - } else { - ICHECK(false) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". " - << "Only te.placeholder and te.compute are allowed for now."; + info->tensor2buffers[tensor] = buffer; } + } else if (const auto* compute_op = op.as()) { + // Case 2. ComputeOp (te.compute) + root_stmts->push_back( + GenerateStmtFromCompute(GetRef(compute_op), info, analyzer)); + } else if (const auto extern_op = op.as()) { + // Case 3. ExternOp (te.extern) + root_stmts->push_back(GenerateStmtFromExternOp(GetRef(extern_op), info)); + } else { + ICHECK(false) << "TypeError: Unsupported Operation: " << op->GetTypeKey() << ". " + << "Only te.placeholder and te.compute are allowed for now."; } +} - // Step 4. Create func and complete it. 
+PrimFunc GenerateAndCompletePrimFunc(const Array& arg_list, + const Array& root_stmts, CreateFuncInfo* info) { Array parameters; Map buffer_map; for (const te::Tensor& tensor : arg_list) { Var arg("var_" + tensor->GetNameHint(), PrimType(DataType::Handle())); parameters.push_back(arg); - auto it = info.tensor2buffers.find(tensor); - ICHECK(it != info.tensor2buffers.end()); + auto it = info->tensor2buffers.find(tensor); + ICHECK(it != info->tensor2buffers.end()); buffer_map.Set(arg, it->second); } PrimFunc func = WithAttrs(PrimFunc(/*params=*/std::move(parameters), @@ -460,10 +473,32 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { {{"global_symbol", String("main")}, {"tir.noalias", Bool(true)}}); const auto* complete = runtime::Registry::Get("script.Complete"); ICHECK(complete); - func = (*complete)(std::move(func), info.root_alloc); + func = (*complete)(std::move(func), info->root_alloc); return LayoutFreePlaceholdersNormalizer().Process(std::move(func)); } +PrimFunc CreatePrimFunc(const Array& arg_list) { + // Infomations used in CreatePrimFunc and its sub-functions. + CreateFuncInfo info(arg_list); + // Root body stmts. + Array root_stmts; + // Analyzer + arith::Analyzer analyzer; + + // Step 1. Create ordered array of operations and validate they are supported. + Array order = CollectOrderedOps(arg_list); + + // Step 2. Initialize buffer binds map + InitializeBufferBinds(order, &info); + + // Step 3. Rewrite compute stages into blocks. + for (const te::Operation& op : order) { + RewriteStageToBlock(op, &info, &root_stmts, &analyzer); + } + // Step 4. Create func and complete prim func. + return GenerateAndCompletePrimFunc(arg_list, root_stmts, &info); +} + TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc); } // namespace tir diff --git a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py new file mode 100644 index 000000000000..b62b638c03dc --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import numpy as np +import tvm +import tvm.testing +import tvm.topi.testing + +from tvm.script import tir as T +from tvm import tir, te, relay, topi, autotvm +from tvm.relay.testing.temp_op_attr import TempOpAttr +from tvm.meta_schedule import ApplyHistoryBest +from tvm.meta_schedule.testing import apply_fixed_schedules + + +def compute_tir_conv2d_nchw_oihw(data_shape, weight_shape, dtype): + assert dtype == "float32" + OC, IC, FH, FW = weight_shape + + padding = (0, 0, 0, 0) + strides = (1, 1) + dilation = (1, 1) + output_shape = ( + data_shape[0], + weight_shape[0], + (data_shape[2] - ((weight_shape[2] - 1) * dilation[0] + 1) + padding[0] + padding[1]) + // strides[0] + + 1, + (data_shape[3] - ((weight_shape[3] - 1) * dilation[1] + 1) + padding[2] + padding[3]) + // strides[1] + + 1, + ) + N, K, BH, BW = output_shape + + # fmt: off + @T.prim_func + def conv2d(a: T.handle, filt: T.handle, b: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, data_shape, dtype=dtype) + Filter = T.match_buffer(filt, weight_shape, dtype=dtype) + B = T.match_buffer(b, output_shape, dtype=dtype) + for n, k, bh, bw in T.grid(N, K, BH, BW): + with T.block("init"): + vn, vk, vbh, vbw = T.axis.remap("SSSS", [n, k, bh, bw]) + B[vn, vk, vbh, vbw] = T.float32(0) + for ic, fh, fw in T.grid(IC, FH, FW): + with T.block("update"): + vn, vk, vbh, vbw, vc, vfh, vfw = T.axis.remap("SSSSRRR", [n, k, bh, bw, ic, fh, fw]) + B[vn, vk, vbh, vbw] = B[vn, vk, vbh, vbw] + A[vn, vc, vbh + vfh, vbw + vfw] * Filter[vk, vc, vfh, vfw] + # fmt: on + + return conv2d + + +def schedule_tir_conv2d_nchw_oihw(sch): + update_block = sch.get_block("update") + vn, vk, vbh, vbw, vc, vfh, vfw = sch.get_loops(update_block) + sch.split(vk, factors=(None, 32)) + + +@autotvm.register_topi_compute("test/conv2d_1") +def _compute_conv2d_1(cfg, input, filter, strides, padding, dilation, out_dtype): + prim_func = compute_tir_conv2d_nchw_oihw(input.shape, filter.shape, input.dtype) + output = te.extern_primfunc([input, filter], prim_func, name="tir") + return output + + +@autotvm.register_topi_schedule("test/conv2d_1") +def _schedule_conv2d_1(cfg, outs): + s = te.create_schedule([x.op for x in outs]) + return s + + +@tvm.target.override_native_generic_func("test_conv2d_strategy") +def _tmp_strategy(attrs, inputs, out_type, target): + strategy = relay.op.OpStrategy() + if attrs.groups == 1 and attrs.data_layout == "NCHW" and attrs.kernel_layout == "OIHW": + strategy.add_implementation( + relay.op.strategy.wrap_compute_conv2d(_compute_conv2d_1), + relay.op.strategy.wrap_topi_schedule(_schedule_conv2d_1), + name="conv2d_2", + plevel=15, + ) + else: + raise ValueError("No valid strategy found") + return strategy + + +def get_conv2d(data_shape, weight_shape, **kwargs): + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + conv2d = relay.nn.conv2d( + data, + weight, + **kwargs, + ) + return relay.Function([data, weight], conv2d) + + +def get_ref(data, weight, stride, padding): + return tvm.topi.testing.conv2d_nchw_python(data, weight, stride, padding) + + +def test_conv2d(): + N, IC, H, W = 1, 64, 56, 56 + OC, IC, FH, FW = 128, 64, 3, 3 + data_shape = (N, IC, H, W) + weight_shape = (OC, IC, FH, FW) + padding = (0, 0) + strides = (1, 1) + + relay_mod = tvm.IRModule.from_expr( + get_conv2d( + data_shape, + weight_shape, + padding=padding, + strides=strides, + channels=OC, + kernel_size=(FH, FW), + data_layout="NCHW", + 
kernel_layout="OIHW", + ) + ) + + data_np = np.random.randn(*data_shape).astype("float32") + weight_np = np.random.randn(*weight_shape).astype("float32") + + target = "llvm" + params = {"weight": weight_np} + + def schedule_fn(task, sch): + if "nn_conv2d" in task.task_name: + schedule_tir_conv2d_nchw_oihw(sch) + return True + return False + + with TempOpAttr("nn.conv2d", "FTVMStrategy", _tmp_strategy): + database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) + with ApplyHistoryBest(database): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + lib = relay.build(relay_mod, target=target, params=params) + + dev = tvm.device(target, 0) + + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + runtime.set_input("data", data_np) + runtime.run() + + out = runtime.get_output(0).numpy() + + ref = get_ref(data_np, weight_np, strides, padding) + + tvm.testing.assert_allclose(out, ref, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + test_conv2d() diff --git a/tests/python/unittest/test_tir_te_extern_primfunc.py b/tests/python/unittest/test_tir_te_extern_primfunc.py new file mode 100644 index 000000000000..26752145620a --- /dev/null +++ b/tests/python/unittest/test_tir_te_extern_primfunc.py @@ -0,0 +1,257 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import sys +import pytest +import numpy as np + +import tvm +import tvm.testing +from tvm import tir, te, TVMError +from tvm.script import tir as T +from tvm.arith import _ffi_api as _ffi_arith_api +from tvm.tir.schedule import _ffi_api as _ffi_schedule_api + + +# TODO(csullivan): Additional tests cases needed: +# - PrimFunc with 1 arg, inplace update +# - PrimFunc with buffer that uses custom storage_scope + + +@T.prim_func +def func_1(A: T.Buffer[(16,), "float32"], C: T.Buffer[(1,), "float32"]): + for i in T.serial( + 0, + 16, + ): + with T.block(): + B = T.alloc_buffer((1,), dtype="float32") + with T.block(): + B[0] = A[i] * T.float32(2) + with T.block(): + C[0] = C[0] + A[i] + B[0] + T.float32(1) + A[i] = B[0] + T.float32(1) + + +def verify_func_1(module): + a_np = np.random.randint(low=-128, high=127, size=(16,)).astype(np.float32) + c_np = np.zeros((1,), dtype=np.float32) + a = tvm.nd.array(a_np, device=tvm.cpu(0)) + c = tvm.nd.array(c_np, device=tvm.cpu(0)) + + module(a, c) + tvm.testing.assert_allclose(c_np + np.sum(3 * a_np + 1), c.numpy(), rtol=1e-4) + # also test in place update + tvm.testing.assert_allclose(a_np * 2 + 1, a.numpy(), rtol=1e-4) + + +@T.prim_func +def func_2( + C: T.Buffer[(1,), "float32"], A: T.Buffer[(16,), "float32"], D: T.Buffer[(2,), "float32"] +): + for i in T.serial( + 0, + 16, + ): + with T.block(): + B = T.alloc_buffer((1,), dtype="float32") + with T.block(): + B[0] = A[i] * T.float32(2) + with T.block(): + C[0] = C[0] + A[i] + B[0] + T.float32(1) + D[0] + A[i] = B[0] + T.float32(1) + D[1] + + +def verify_func_2(module): + a_np = np.random.randint(low=-128, high=127, size=(16,)).astype(np.float32) + d_np = np.random.randint(low=-128, high=127, size=(2,)).astype(np.float32) + c_np = np.zeros((1,), dtype=np.float32) + a = tvm.nd.array(a_np, device=tvm.cpu(0)) + d = tvm.nd.array(d_np, device=tvm.cpu(0)) + c = tvm.nd.array(c_np, device=tvm.cpu(0)) + + module(c, a, d) + tvm.testing.assert_allclose(c_np + np.sum(3 * a_np + 1 + d_np[0]), c.numpy(), rtol=1e-4) + tvm.testing.assert_allclose(a_np * 2 + 1 + d_np[1], a.numpy(), rtol=1e-4) + + +@T.prim_func +def func_3( + C: T.Buffer[(1,), "float32"], + A: T.Buffer[(16,), "float32"], + D: T.Buffer[(2,), "float32"], + E: T.Buffer[(16,), "float32"], + F: T.Buffer[(16,), "float32"], +): + for i in T.serial( + 0, + 16, + ): + with T.block(): + B = T.alloc_buffer((1,), dtype="float32") + with T.block(): + B[0] = A[i] * T.float32(2) + with T.block(): + E[i] = A[i] + F[i] = E[i] + 1.0 + C[0] = C[0] + A[i] + B[0] + T.float32(1) + D[0] + A[i] = B[0] + T.float32(1) + D[1] + + +def verify_func_3(module): + a_np = np.random.randint(low=-128, high=127, size=(16,)).astype(np.float32) + d_np = np.random.randint(low=-128, high=127, size=(2,)).astype(np.float32) + c_np = np.zeros((1,), dtype=np.float32) + e_np = np.zeros((16,), dtype=np.float32) + f_np = np.zeros((16,), dtype=np.float32) + a = tvm.nd.array(a_np, device=tvm.cpu(0)) + d = tvm.nd.array(d_np, device=tvm.cpu(0)) + c = tvm.nd.array(c_np, device=tvm.cpu(0)) + e = tvm.nd.array(e_np, device=tvm.cpu(0)) + f = tvm.nd.array(f_np, device=tvm.cpu(0)) + + module(c, a, d, e, f) + tvm.testing.assert_allclose(c_np + np.sum(3 * a_np + 1 + d_np[0]), c.numpy(), rtol=1e-4) + tvm.testing.assert_allclose(a_np * 2 + 1 + d_np[1], a.numpy(), rtol=1e-4) + tvm.testing.assert_allclose(a_np, e.numpy(), rtol=1e-4) + tvm.testing.assert_allclose(a_np + 1, f.numpy(), rtol=1e-4) + + +@T.prim_func +def func_4( + C: T.Buffer[(1,), "float32"], + A: T.Buffer[(16,), "float32"], + F: T.Buffer[(16,), 
"float32"], + D: T.Buffer[(2,), "float32"], + E: T.Buffer[(16,), "float32"], +): + for i in T.serial( + 0, + 16, + ): + with T.block(): + B = T.alloc_buffer((1,), dtype="float32") + with T.block(): + B[0] = A[i] * T.float32(2) + with T.block(): + E[i] = A[i] + F[i] = E[i] + 1.0 + C[0] = C[0] + A[i] + B[0] + T.float32(1) + D[0] + A[i] = B[0] + T.float32(1) + D[1] + + +def verify_func_4(module): + a_np = np.random.randint(low=-128, high=127, size=(16,)).astype(np.float32) + d_np = np.random.randint(low=-128, high=127, size=(2,)).astype(np.float32) + c_np = np.zeros((1,), dtype=np.float32) + e_np = np.zeros((16,), dtype=np.float32) + f_np = np.zeros((16,), dtype=np.float32) + a = tvm.nd.array(a_np, device=tvm.cpu(0)) + d = tvm.nd.array(d_np, device=tvm.cpu(0)) + c = tvm.nd.array(c_np, device=tvm.cpu(0)) + e = tvm.nd.array(e_np, device=tvm.cpu(0)) + f = tvm.nd.array(f_np, device=tvm.cpu(0)) + + module(c, a, f, d, e) + tvm.testing.assert_allclose(c_np + np.sum(3 * a_np + 1 + d_np[0]), c.numpy(), rtol=1e-4) + tvm.testing.assert_allclose(a_np * 2 + 1 + d_np[1], a.numpy(), rtol=1e-4) + tvm.testing.assert_allclose(a_np, e.numpy(), rtol=1e-4) + tvm.testing.assert_allclose(a_np + 1, f.numpy(), rtol=1e-4) + + +class TestPrimFuncs: + func, verify = tvm.testing.parameters( + [func_1, verify_func_1], + [func_2, verify_func_2], + [func_3, verify_func_3], + [func_4, verify_func_4], + ) + + def test_primfunc_call(self, func, verify): + target = tvm.target.Target("llvm") + func = tvm.build(func, target=target) + verify(func) + + def test_te_extern_call(self, func, verify): + ir_mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + prim_func = ir_mod["main"] + + input_tensors = create_input_tensors_for_primfunc(prim_func) + output = te.extern_primfunc(input_tensors, prim_func) + rt_prim_func = te.create_prim_func(tensors_from_extern_op(output, prim_func)) + tvm.ir.assert_structural_equal(tvm.lower(prim_func), tvm.lower(rt_prim_func)) + + target = tvm.target.Target("llvm") + func = tvm.build(rt_prim_func, target=target) + verify(func) + + +def tensors_from_extern_op(extern, func): + if isinstance(extern, list): + output_tensors = extern + else: + output_tensors = [extern] + output_buffers = [] + input_buffers = [] + input_tensors = [] + for ext in output_tensors: + output_buffers.extend(ext.op.output_placeholders) + input_buffers.extend(ext.op.input_placeholders) + input_tensors.extend(ext.op.input_tensors) + input_binds = dict(zip(input_buffers, input_tensors)) + output_binds = dict(zip(output_buffers, output_tensors)) + buffer_to_tensor = {**input_binds, **output_binds} + ordered_tensors = [] + for var in func.params: + buf = func.buffer_map[var] + ordered_tensors.append(buffer_to_tensor[buf]) + return ordered_tensors + + +def create_input_tensors_for_primfunc(primfunc): + access_map = {k: tuple(v) for k, v in _ffi_arith_api.DomainTouchedAccessMap(primfunc).items()} + in_buffers = [buf for buf, access in access_map.items() if len(access[0])] + out_buffers = [buf for buf, access in access_map.items() if len(access[1])] + assert in_buffers, "PrimFunc has no input buffers" + assert out_buffers, "PrimFunc has no output buffers" + + outputs = [] + inplace = [] + inputs = in_buffers + for obuf in out_buffers: + if obuf in in_buffers: + inplace.append(obuf) + else: + outputs.append(obuf) + + if not outputs: + iobuf = inplace.pop() + inputs.remove(iobuf) + outputs = [iobuf] + + def create_tensors(input_buffers): + tensors = [] + for buf in input_buffers: + t = te.placeholder(buf.shape, 
dtype=buf.dtype, name=buf.name + "_placeholder") + tensors.append(t) + return tensors + + return create_tensors(inputs) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv))