Skip to content

Commit

Permalink
[PYTHON] Enhance with_attr API, cleanup MakeAPILegacy in testcases (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and dpankratz committed Apr 24, 2020
1 parent 3838b27 commit 149f5d6
Show file tree
Hide file tree
Showing 18 changed files with 117 additions and 141 deletions.
31 changes: 31 additions & 0 deletions python/tvm/ir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
"""Function defintiions."""
from enum import IntEnum
import tvm.runtime

from .expr import RelayExpr
from . import _ffi_api

Expand All @@ -34,3 +36,32 @@ def attrs(self):
"""Return the attrs member of the function.
"""
return _ffi_api.BaseFunc_Attrs(self)

def with_attr(self, attr_key_or_dict, attr_value=None):
"""Create a new copy of the function and update the attribute.
Parameters
----------
attr_key_or_dict : Union[str, dict]
The attribute key to use or a dict containing multiple key value pairs.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
# make sure we first copy so that we can safely do copy on write
# for multiple updates.
res = _ffi_api.BaseFuncCopy(self)

if isinstance(attr_key_or_dict, dict):
for key, val in attr_key_or_dict.items():
res = _ffi_api.BaseFuncWithAttr(
res._move(), key, tvm.runtime.convert(val))
return res

return _ffi_api.BaseFuncWithAttr(
res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value))
19 changes: 0 additions & 19 deletions python/tvm/relay/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,3 @@ def __call__(self, *args):
Arguments.
"""
return Call(self, args, None, None)

def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute
Parameters
----------
attr_key : str
The attribute key to use.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.FunctionWithAttr(
self, attr_key, convert(attr_value))
37 changes: 0 additions & 37 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,41 +168,4 @@ def compare_derivative(j, n_der, grad):
x_name, grad.shape, dist, max_diff, avg_diff)


def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to build a Module from statement.
Used for migrating existing test cases only.
Parameters
----------
stmt: Stmt
The input statement.
name: str
The name of the funciton.
args: list of Buffer or Vars
The function arguments
num_unpacked_args: int
Number of unpacked arguments.
nolias: bool
Whether allow noalias.
Returns
-------
mod : IRModule
The created IRModule.
"""
assert num_unpacked_args == 0
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: f})
return mod


tvm._ffi._init_api("testing", __name__)
19 changes: 0 additions & 19 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,3 @@ def __init__(self,

self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)

def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute
Parameters
----------
attr_key : str
The attribute key to use.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.PrimFuncWithAttr(
self, attr_key, tvm.runtime.convert(attr_value))
26 changes: 26 additions & 0 deletions src/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/function.h>
// NOTE: reverse dependency on relay, tir/
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
//
// Rationale: We calls into the type specific WithAttr function
#include <tvm/tir/function.h>
#include <tvm/relay/function.h>


namespace tvm {

Expand All @@ -31,4 +39,22 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
return func->attrs;
});

TVM_REGISTER_GLOBAL("ir.BaseFuncCopy")
.set_body_typed([](BaseFunc func) {
return func;
});

TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
if (func->IsInstance<tir::PrimFuncNode>()) {
return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
} else if (func->IsInstance<relay::FunctionNode>()) {
return WithAttr(Downcast<relay::Function>(std::move(func)), key, value);
} else {
LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
return func;
}
});


} // namespace tvm
8 changes: 7 additions & 1 deletion src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,19 @@ IRModule IRModule::FromExpr(
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = IRModule(global_funcs, type_definitions);
BaseFunc func;
std::string gv_name = "main";

if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
gv_name = opt.value();
}

} else {
func = relay::Function(relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVar("main");
auto main_gv = GlobalVar(gv_name);
mod->Add(main_gv, func);
return mod;
}
Expand Down
6 changes: 0 additions & 6 deletions src/relay/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,5 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< node->attrs << ")";
});

TVM_REGISTER_GLOBAL("relay.ir.FunctionWithAttr")
.set_body_typed(
[](Function func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});

} // namespace relay
} // namespace tvm
6 changes: 0 additions & 6 deletions src/tir/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,5 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
return PrimFunc(params, body, ret_type, buffer_map, attrs);
});


TVM_REGISTER_GLOBAL("tir.PrimFuncWithAttr")
.set_body_typed([](PrimFunc func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});

} // namespace tir
} // namespace tvm
3 changes: 2 additions & 1 deletion tests/python/unittest/test_runtime_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def test_dltensor_compatible():
A[i + 1] = A[i] + 1
stmt = ib.get()

mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange"))
f = tvm.build(mod, target="stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a)
Expand Down
7 changes: 5 additions & 2 deletions tests/python/unittest/test_runtime_module_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,11 @@ def save_object(names):
tvm.tir.Store(Ab.data,
tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1))
m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
m = tvm.driver.build(m, target="llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr(
"global_symbol", "main")
)
m = tvm.driver.build(mod, target="llvm")
for name in names:
m.save(name)

Expand Down
12 changes: 8 additions & 4 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ def test_llvm_intrin():
"int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
body = ib.get()

func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A], body).with_attr(
"global_symbol", "prefetch")
)
fcode = tvm.build(mod, None, "llvm")


def test_llvm_overloaded_intrin():
Expand Down Expand Up @@ -111,8 +114,9 @@ def test_llvm_lookup_intrin():
x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
ib.emit(x)
body = ib.get()
func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
fcode = tvm.build(mod, None, "llvm")


def test_llvm_large_uintimm():
Expand Down
23 changes: 8 additions & 15 deletions tests/python/unittest/test_target_codegen_static_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,6 @@
import numpy as np


def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to create a API"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)


def test_static_callback():
dtype = 'int64'
n = te.size_var('n')
Expand All @@ -44,8 +33,11 @@ def test_static_callback():
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm")

mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp")
)
f = tvm.driver.build(mod, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
f(a)
Expand All @@ -67,8 +59,9 @@ def test_cb(sh, A):
return sh

stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm")
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
f = tvm.driver.build(mod, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)

Expand Down
34 changes: 14 additions & 20 deletions tests/python/unittest/test_target_codegen_vm_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,6 @@ def run_jit(fapi, check):
s = f.get_source()
check(f)


def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to create a API"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.noalias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)


def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func
Expand All @@ -48,8 +36,11 @@ def tvm_call_back_get_shape(shape0):
n = te.size_var('n')
Ab = tvm.tir.decl_buffer((n, ), "float32")
stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True)
run_jit(fapi, lambda f: f(a))

mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "print_shape"))

run_jit(mod, lambda f: f(a))


@tvm.register_func
Expand All @@ -69,12 +60,13 @@ def test_stack_vm_loop():
ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))

stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "ramp"))
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
f(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
run_jit(fapi, check)
run_jit(mod, check)


def test_stack_vm_cond():
Expand All @@ -91,14 +83,15 @@ def test_stack_vm_cond():
A[i + 1] = A[i] + 2

stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
y = np.arange(a.shape[0]) * 2
y[5:] -= 1
np.testing.assert_equal(a.asnumpy(), y)
run_jit(fapi, check)
run_jit(mod, check)

def test_vm_parallel():
dtype = 'int64'
Expand All @@ -110,12 +103,13 @@ def test_vm_parallel():
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test"))
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
np.testing.assert_equal(a.asnumpy(), np.ones(a.shape[0]))
run_jit(fapi, check)
run_jit(mod, check)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def test_prim_func():
assert func.buffer_map[func.params[2]].same_as(b)

assert len(func.buffer_map) == 1
f2 = func.with_attr("calling_conv", 1)
f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True})
assert f2.attrs["calling_conv"].value == 1
assert func.attrs is None

Expand Down
Loading

0 comments on commit 149f5d6

Please sign in to comment.