Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[Fix] Fix some errors in unittests (apache#12245)
Browse files Browse the repository at this point in the history
- test_aot_legalize_packed_call.py: `T.preflattened_buffer` returns `void`
- test_tir_intrin.py: `type` here should be `buffer_type`
- test_tir_transform_flatten_buffer.py: `extents` should be `list`
- test_tir_transform_hoist_expression.py: change `tir` into `T` and register `Let` expression in `script/tir/intrin.py`
- test_tir_transform_storage_flatten.py: `T.allocate` has no argument named `strides`
  • Loading branch information
cyx-6 authored and xinetzone committed Nov 25, 2022
1 parent 9ca348b commit ea8ce7f
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 17 deletions.
5 changes: 5 additions & 0 deletions python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ def Select(cond, if_body, else_body, span): # pylint: disable=invalid-name
return tvm.tir.Select(cond, if_body, else_body, span)


@register
def Let(var, value, body, span): # pylint: disable=invalid-name
return tvm.tir.Let(var, value, body, span)


@register
class EvaluateIntrin(Intrin):
def __init__(self):
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_aot_legalize_packed_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def tvm_test_cpacked(
A: T.handle, B: T.handle, C: T.handle, device_context: T.handle
) -> T.handle:
A_0 = T.match_buffer(A, (1,), dtype="float32")
A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32")
T.preflattened_buffer(A_0, (1,), dtype="float32")
B_0 = T.match_buffer(B, (1,), dtype="float32")
B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32")
T.preflattened_buffer(B_0, (1,), dtype="float32")
C_0 = T.match_buffer(C, (1,), dtype="float32")
C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32")
T.preflattened_buffer(C_0, (1,), dtype="float32")
T.evaluate(C)

@T.prim_func
Expand Down Expand Up @@ -62,11 +62,11 @@ def tvm_test_cpacked(
A: T.handle, B: T.handle, C: T.handle, device_context: T.handle
) -> T.handle:
A_0 = T.match_buffer(A, (1,), dtype="float32")
A_0pre = T.preflattened_buffer(A_0, (1,), dtype="float32")
T.preflattened_buffer(A_0, (1,), dtype="float32")
B_0 = T.match_buffer(B, (1,), dtype="float32")
B_0pre = T.preflattened_buffer(B_0, (1,), dtype="float32")
T.preflattened_buffer(B_0, (1,), dtype="float32")
C_0 = T.match_buffer(C, (1,), dtype="float32")
C_0pre = T.preflattened_buffer(C_0, (1,), dtype="float32")
T.preflattened_buffer(C_0, (1,), dtype="float32")
T.evaluate(C)

@T.prim_func
Expand Down
8 changes: 4 additions & 4 deletions tests/python/unittest/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None:
elem_offset=0,
align=128,
offset_factor=1,
type="auto",
buffer_type="auto",
)
B_1 = T.match_buffer(
B,
Expand All @@ -214,7 +214,7 @@ def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None:
elem_offset=0,
align=128,
offset_factor=1,
type="auto",
buffer_type="auto",
)
C_1 = T.match_buffer(
C,
Expand All @@ -223,7 +223,7 @@ def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None:
elem_offset=0,
align=128,
offset_factor=1,
type="auto",
buffer_type="auto",
)
d_1 = T.match_buffer(
d,
Expand All @@ -232,7 +232,7 @@ def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None:
elem_offset=0,
align=128,
offset_factor=1,
type="auto",
buffer_type="auto",
)
# body
for i in T.serial(0, n):
Expand Down
8 changes: 4 additions & 4 deletions tests/python/unittest/test_tir_transform_flatten_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ def multi_alloc_func(a: T.handle, d: T.handle) -> None:

@T.prim_func
def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None:
A = T.match_buffer(a, (128), "float32")
D = T.match_buffer(d, (128), "float32")
A = T.match_buffer(a, 128, "float32")
D = T.match_buffer(d, 128, "float32")
T.preflattened_buffer(A, (4, 32), "float32", data=A.data)
T.preflattened_buffer(D, (4, 32), "float32", data=D.data)

for i, j in T.grid(4, 32):
B = T.allocate((128), "float32", "global")
C = T.allocate((128), "float32", "global")
B = T.allocate([128], "float32", "global")
C = T.allocate([128], "float32", "global")
B[i * 32 + j] = A[i * 32 + j] + 1.0
C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
D[i * 32 + j] = C[i * 32 + j] * 2.0
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_tir_transform_hoist_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ class TestHoistLetExpr(BaseBeforeAfter):
def before(A: T.Buffer[(4, 4), "float32"]):
for i, j in T.grid(4, 4):
x = T.var("float32")
A[i, j] = tir.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32"))
A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32"))

@T.prim_func
def expected(A: T.Buffer[(4, 4), "float32"]):
Expand All @@ -467,7 +467,7 @@ class TestSuppressHoistLetExpr(BaseBeforeAfter):
def before(A: T.Buffer[(4, 4), "float32"]):
for i, j in T.grid(4, 4):
x = T.var("float32")
A[i, j] = tir.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32"))
A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32"))

expected = before

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def main(A_param: T.handle, C_param: T.handle):
threadIdx_x = T.env_thread("threadIdx.x")
T.launch_thread(threadIdx_x, 1)
for i in T.serial(0, 100):
B = T.allocate([4], "float32", scope="shared", strides=[1])
B = T.allocate([4], "float32", scope="shared")
with T.attr(B.data, "double_buffer_scope", 1):
for j in T.serial(0, 4):
B[j] = A[4 * i + j]
Expand Down

0 comments on commit ea8ce7f

Please sign in to comment.