-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TEST/PYTHON] Add unittest folder, add a build pipeline. Rename Buffe…
…r.ptr to Buffer.data to be consistent with Array.
- Loading branch information
Showing
26 changed files
with
155 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,3 +17,4 @@ | |
|
||
from ._base import TVMError | ||
from .api import * | ||
from .build import build |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
"""The build pipeline in python. | ||
Eventually some of these pipelines will be moved to C++. | ||
But the first pipeline will be kept in python for ease of change and evolving. | ||
""" | ||
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments | ||
|
||
from . import api | ||
from . import tensor | ||
from . import schedule | ||
from . import expr | ||
from . import ir_pass | ||
from . import codegen | ||
|
||
def build(sch, | ||
args, | ||
target, | ||
name="default_function", | ||
binds=None, | ||
record_codes=None): | ||
"""Build a function with arguments as signiture. | ||
Parameters | ||
---------- | ||
sch : tvm.Schedule | ||
The schedule to be builded | ||
args : list of Buffer or Tensor or Var | ||
The argument lists to the function. | ||
target : str | ||
The target of the compilation. | ||
name : str | ||
The name of result function. | ||
binds : dict, optional | ||
Dictionary that maps the binding of symbolic buffer to Tensor. | ||
By default, a new buffer is created for each tensor in the argument. | ||
Returns | ||
------- | ||
f : Function, or pair of functions | ||
The result function. | ||
If the function requires host space allocation, | ||
a pair of functions will be returned. | ||
""" | ||
binds = {} if binds is None else binds.copy() | ||
arg_list = [] | ||
for x in args: | ||
if isinstance(x, tensor.Tensor): | ||
buf = api.Buffer(x.shape, dtype=x.dtype, name=x.op.name) | ||
assert x not in binds | ||
binds[x] = buf | ||
arg_list.append(buf) | ||
elif isinstance(x, schedule.Buffer): | ||
arg_list.append(x) | ||
elif isinstance(x, expr.Var): | ||
arg_list.append(x) | ||
else: | ||
raise ValueError("args must be Tensor, Buffer or Var") | ||
|
||
# lowering | ||
bounds = schedule.InferBound(sch) | ||
stmt = ir_pass.ScheduleOps(sch, bounds) | ||
stmt = ir_pass.StorageFlatten(stmt, binds) | ||
stmt = ir_pass.Simplify(stmt) | ||
fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list)) | ||
fsplits = codegen.SplitHostDevice(fapi) | ||
|
||
if record_codes is not None: | ||
output_ssa = False | ||
for i, f in enumerate(fsplits): | ||
t = target if i >= 1 else "c" | ||
record_codes.append(codegen.CompileToC(f, output_ssa, t)) | ||
|
||
if target == "cuda": | ||
ret = codegen.BuildNVRTC(fsplits, "stackvm") | ||
elif target == "opencl": | ||
ret = codegen.BuildOpenCL(fsplits, "stackvm") | ||
else: | ||
raise ValueError("Unknown target %s" % target) | ||
return ret |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import tvm | ||
import numpy as np | ||
|
||
def test_add(): | ||
# graph | ||
n = tvm.Var('n') | ||
A = tvm.placeholder((n,), name='A') | ||
B = tvm.placeholder((n,), name='B') | ||
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') | ||
# schedule | ||
s = tvm.Schedule(C.op) | ||
# create iter var and assign them tags. | ||
num_thread = 256 | ||
block_x = tvm.IterVar(thread_tag="blockIdx.x") | ||
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") | ||
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=block_x) | ||
_, x = s[C].split(x, outer=thread_x) | ||
|
||
# one line to build the function. | ||
codes = [] | ||
fadd = tvm.build(s, args=[A, B, C], | ||
target="cuda", name="myadd", | ||
record_codes=codes) | ||
for c in codes: | ||
print(c) | ||
|
||
# call the function | ||
num_device = 1 | ||
for i in range(num_device): | ||
ctx = tvm.gpu(i) | ||
if not ctx.enabled: | ||
continue | ||
# launch the kernel. | ||
n = 1027 | ||
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) | ||
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx) | ||
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) | ||
fadd(a, b, c) | ||
np.testing.assert_allclose( | ||
c.asnumpy(), a.asnumpy() + b.asnumpy()) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_add() |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters