Commit 72f3e7a
[MetaSchedule][Test] Add unittests for GMM (apache#12243)
junrushao authored and xinetzone committed Nov 25, 2022
1 parent bb7f48f commit 72f3e7a
Showing 2 changed files with 205 additions and 0 deletions.
123 changes: 123 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -1079,10 +1079,133 @@ def dil_2(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7,
)


def test_cpu_gmm():
    # fmt: off
    @T.prim_func
    def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
            Z_global = T.alloc_buffer([1, 128, 128], dtype="float32")
            for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1 in T.grid(1, 4, 2, 1, 1, 8):
                for i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(128, 1, 16, 1, 1, 1, 2, 8):
                    with T.block("Z"):
                        b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
                        i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
                        j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
                        k = T.axis.reduce(128, i3_1 + i3_0)
                        T.reads(X[b, i, k], Y[b, k, j])
                        T.writes(Z_global[b, i, j])
                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                        with T.init():
                            Z_global[b, i, j] = T.float32(0)
                        Z_global[b, i, j] = Z_global[b, i, j] + X[b, i, k] * Y[b, k, j]
                for ax0, ax1, ax2 in T.grid(1, 32, 8):
                    with T.block("Z_global"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i1_0 * 32 + ax1)
                        v2 = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + ax2)
                        T.reads(Z_global[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_global[v0, v1, v2]
    @T.prim_func
    def gmm_1(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
            Z_global = T.alloc_buffer([1, 128, 128], dtype="float32")
            for i0_0, i1_0, i2_0 in T.grid(1, 4, 2):
                for i0_1, i1_1, i2_1, i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8):
                    with T.block("Z"):
                        b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
                        i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
                        j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
                        k = T.axis.reduce(128, i3_1 + i3_0)
                        T.reads(X[b, i, k], Y[b, k, j])
                        T.writes(Z_global[b, i, j])
                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                        with T.init():
                            Z_global[b, i, j] = T.float32(0)
                        Z_global[b, i, j] = Z_global[b, i, j] + X[b, i, k] * Y[b, k, j]
                for ax0, ax1, ax2 in T.grid(1, 32, 64):
                    with T.block("Z_global"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i1_0 * 32 + ax1)
                        v2 = T.axis.spatial(128, i2_0 * 64 + ax2)
                        T.reads(Z_global[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_global[v0, v1, v2]
    @T.prim_func
    def gmm_2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
            for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i0_2, i1_2, i2_2, i3_1, i0_3, i1_3, i2_3 in T.grid(1, 4, 2, 1, 1, 8, 128, 1, 16, 1, 1, 1, 2, 8):
                with T.block("Z"):
                    b = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
                    i = T.axis.spatial(128, i1_0 * 32 + i1_1 * 32 + i1_2 * 2 + i1_3)
                    j = T.axis.spatial(128, i2_0 * 64 + i2_1 * 8 + i2_2 * 8 + i2_3)
                    k = T.axis.reduce(128, i3_1 + i3_0)
                    T.reads(X[b, i, k], Y[b, k, j])
                    T.writes(Z[b, i, j])
                    T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
                    with T.init():
                        Z[b, i, j] = T.float32(0)
                    Z[b, i, j] = Z[b, i, j] + X[b, i, k] * Y[b, k, j]
    # fmt: on
    decision_0 = [
        ("SamplePerfectTile", [1, 1, 1, 1]),
        ("SamplePerfectTile", [4, 1, 16, 2]),
        ("SamplePerfectTile", [2, 8, 1, 8]),
        ("SamplePerfectTile", [128, 1]),
        ("SampleCategorical", 1),
    ]
    decision_1 = [
        ("SamplePerfectTile", [1, 1, 1, 1]),
        ("SamplePerfectTile", [4, 1, 16, 2]),
        ("SamplePerfectTile", [2, 8, 1, 8]),
        ("SamplePerfectTile", [128, 1]),
        ("SampleCategorical", 1),
    ]
    decision_2 = [
        ("SamplePerfectTile", [1, 1, 1, 1]),
        ("SamplePerfectTile", [4, 1, 16, 2]),
        ("SamplePerfectTile", [2, 8, 1, 8]),
        ("SamplePerfectTile", [128, 1]),
        ("SampleCategorical", 1),
    ]
    mod = create_te_workload("GMM", 0)
    actual = ms.TuneContext(
        mod=mod,
        target=_target(),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules="default",
    ).generate_design_space()
    check_sketches(
        mod,
        sketches=actual,
        expected_mods=[gmm_0, gmm_1, gmm_2],
        expected_decisions=[decision_0, decision_1, decision_2],
    )


if __name__ == "__main__":
    test_cpu_c1d()
    test_cpu_c2d()
    test_cpu_c3d()
    test_cpu_cap()
    test_cpu_dep()
    test_cpu_dil()
    test_cpu_gmm()
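
A note on the expected decisions: all three lists are identical. What distinguishes gmm_0, gmm_1, and gmm_2 is only where (or whether) the Z_global cache-write block lands — inside the second tiling level, inside the first, or elided entirely — while the sampled tile factors stay fixed. Each "SamplePerfectTile" entry factors one loop extent of the 1x128x128x128 GMM workload into tiles whose product equals that extent: four factors per spatial axis and two for the reduction axis, matching the "SSRSRS" tiling structure. A minimal illustrative check of that invariant (not part of the commit; the extents are taken from the workload above):

import math

# Mirrors the (extent -> tile sizes) pairs in decision_0/1/2 above;
# a perfect tile split must multiply back to the extent it decomposes.
for extent, tiles in [
    (1, [1, 1, 1, 1]),     # batch axis b
    (128, [4, 1, 16, 2]),  # spatial axis i
    (128, [2, 8, 1, 8]),   # spatial axis j
    (128, [128, 1]),       # reduction axis k
]:
    assert math.prod(tiles) == extent
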
82 changes: 82 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -572,10 +572,92 @@ def dil_0(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7,
)


def test_cuda_gmm():
    # fmt: off
    @T.prim_func
    def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        with T.block("root"):
            T.reads()
            T.writes()
            T.block_attr({"meta_schedule.unroll_explicit":1024})
            Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
            X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
            Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
            for i0_0_i1_0_i2_0_fused in T.thread_binding(1, thread="blockIdx.x"):
                for i0_1_i1_1_i2_1_fused in T.thread_binding(32, thread="vthread.x"):
                    for i0_2_i1_2_i2_2_fused in T.thread_binding(2, thread="threadIdx.x"):
                        for i3_0 in T.serial(1):
                            for ax0_ax1_ax2_fused in T.serial(16384):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128)
                                    v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    T.block_attr({"meta_schedule.cooperative_fetch":2})
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                            for ax0_ax1_ax2_fused in T.serial(16384):
                                with T.block("Y_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, ax0_ax1_ax2_fused // 128)
                                    v2 = T.axis.spatial(128, ax0_ax1_ax2_fused % 128)
                                    T.reads(Y[v0, v1, v2])
                                    T.writes(Y_shared[v0, v1, v2])
                                    T.block_attr({"meta_schedule.cooperative_fetch":1})
                                    Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                            for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(32, 1, 2, 64, 4, 1, 2, 1):
                                with T.block("Z"):
                                    b = T.axis.spatial(1, i0_4 + i0_3)
                                    i = T.axis.spatial(128, i0_1_i1_1_i2_1_fused * 4 + i1_3 * 2 + i1_4)
                                    j = T.axis.spatial(128, i2_4 + i0_2_i1_2_i2_2_fused * 64 + i2_3)
                                    k = T.axis.reduce(128, i3_0 * 128 + i3_1 * 4 + i3_2)
                                    T.reads(X_shared[b, i, k], Y_shared[b, k, j])
                                    T.writes(Z_local[b, i, j])
                                    T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
                                    with T.init():
                                        Z_local[b, i, j] = T.float32(0)
                                    Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                        for ax0, ax1, ax2 in T.grid(1, 4, 64):
                            with T.block("Z_local"):
                                v0 = T.axis.spatial(1, ax0)
                                v1 = T.axis.spatial(128, i0_1_i1_1_i2_1_fused * 4 + ax1)
                                v2 = T.axis.spatial(128, i0_2_i1_2_i2_2_fused * 64 + ax2)
                                T.reads(Z_local[v0, v1, v2])
                                T.writes(Z[v0, v1, v2])
                                Z[v0, v1, v2] = Z_local[v0, v1, v2]
    # fmt: on
    decision_0 = [
        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
        ("SamplePerfectTile", [1, 32, 1, 2, 2]),
        ("SamplePerfectTile", [1, 1, 2, 64, 1]),
        ("SamplePerfectTile", [1, 32, 4]),
        ("SampleCategorical", 1),
        ("SampleCategorical", 0),
        ("SampleCategorical", 4),
    ]
    mod = create_te_workload("GMM", 0)
    actual = ms.TuneContext(
        mod=mod,
        target=_target(),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules="default",
    ).generate_design_space()
    check_sketches(
        mod,
        sketches=actual,
        expected_mods=[gmm_0],
        expected_decisions=[decision_0],
    )


if __name__ == "__main__":
    test_cuda_c1d()
    test_cuda_c2d()
    test_cuda_c3d()
    test_cuda_cap()
    test_cuda_dep()
    test_cuda_dil()
    test_cuda_gmm()
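
The single CUDA sketch follows the "SSSRRSRS" structure: each spatial axis is split five ways, with the three outer levels fused and bound to blockIdx.x, vthread.x, and threadIdx.x, while the reduction axis is split three ways, staging X and Y through shared memory via cooperative fetch and accumulating into a local buffer. The design space checked here can also be generated standalone through the same entry points the test uses — a minimal sketch, assuming a TVM build contemporary with this commit, with an explicit CUDA target tag standing in for the test file's _target() helper:

import tvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.te_workload import create_te_workload

# GMM workload index 0: batched matmul with batch=1, M=N=K=128.
mod = create_te_workload("GMM", 0)
space = ms.TuneContext(
    mod=mod,
    target=tvm.target.Target("nvidia/geforce-rtx-3080"),  # assumed stand-in for _target()
    space_generator=ms.space_generator.PostOrderApply(),
    sch_rules="default",
).generate_design_space()
for sch in space:
    print(sch.mod.script())  # one sketched schedule per design-space point
    print(sch.trace)         # its sampled decisions, e.g. SamplePerfectTile
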
