[MetaSchedule][Test] Add unittests for GRP #12246

Merged · 1 commit · Jul 31, 2022
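For reviewers who want to reproduce the generated design spaces locally, a minimal invocation is sketched below. It assumes a TVM build with LLVM (for the CPU variant) and CUDA (for the GPU variant) and pytest available; the test node IDs come from this diff, while the flags and the use of pytest.main are illustrative.

# Sketch: run only the new GRP tests added in this PR.
import pytest

pytest.main([
    "tests/python/unittest/test_meta_schedule_space_cpu.py::test_cpu_grp",
    "tests/python/unittest/test_meta_schedule_space_cuda.py::test_cuda_grp",
    "-v",
])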
175 changes: 175 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -1201,6 +1201,180 @@ def gmm_2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "flo
)


def test_cpu_grp():
# fmt: off
@T.prim_func
def grp_0(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: T.Buffer[(3, 3, 16, 128), "float32"], conv2d_nhwc: T.Buffer[(1, 28, 28, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
conv2d_nhwc_global = T.alloc_buffer([1, 28, 28, 128], dtype="float32")
for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 7, 1, 2):
for ax0, ax1, ax2, ax3 in T.grid(1, 9, 57, 32):
with T.block("PadInput"):
i0 = T.axis.spatial(1, ax0)
i1 = T.axis.spatial(58, i1_0 * 8 + ax1)
i2 = T.axis.spatial(58, ax2)
i3 = T.axis.spatial(64, i3_0 * 32 + ax3)
T.reads(inputs[i0, i1 - 1, i2 - 1, i3])
T.writes(PadInput[i0, i1, i2, i3])
PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, inputs[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 4, 1, 1):
for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 3, 8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16):
with T.block("conv2d_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
h = T.axis.spatial(28, i1_0 * 4 + i1_1 + i1_2 + i1_3)
w = T.axis.spatial(28, i2_0 * 28 + i2_1 * 28 + i2_2 * 7 + i2_3)
co = T.axis.spatial(128, i3_0 * 64 + i3_1 * 64 + i3_2 * 16 + i3_3)
rh = T.axis.reduce(3, i4_0 * 3 + i4_1)
rw = T.axis.reduce(3, i5_0 + i5_1)
rc = T.axis.reduce(16, i6_0 * 2 + i6_1)
T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc], weight[rh, rw, rc, co])
T.writes(conv2d_nhwc_global[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
conv2d_nhwc_global[n, h, w, co] = T.float32(0)
conv2d_nhwc_global[n, h, w, co] = conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc] * weight[rh, rw, rc, co]
for ax0, ax1, ax2, ax3 in T.grid(1, 1, 28, 64):
with T.block("conv2d_nhwc_global"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(28, i1_0 * 4 + i1_1 + ax1)
v2 = T.axis.spatial(28, ax2)
v3 = T.axis.spatial(128, i3_0 * 64 + ax3)
T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
T.writes(conv2d_nhwc[v0, v1, v2, v3])
conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
@T.prim_func
Contributor: Just out of curiosity, do we need blank lines in between 👀?

Member (author): It's probably not a problem here, but if those functions were in global scope, I would prefer adding blank lines in between.

def grp_1(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: T.Buffer[(3, 3, 16, 128), "float32"], conv2d_nhwc: T.Buffer[(1, 28, 28, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64})
PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
conv2d_nhwc_global = T.alloc_buffer([1, 28, 28, 128], dtype="float32")
for i0, i1, i2, i3 in T.grid(1, 58, 58, 64):
with T.block("PadInput"):
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float32(0), dtype="float32")
for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 7, 1, 2):
for i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 1, 1, 1, 3, 8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16):
with T.block("conv2d_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_0 + i0_1_1 + i0_2)
h = T.axis.spatial(28, i1_0 * 4 + i1_1_1 + i1_2 + i1_3)
w = T.axis.spatial(28, i2_0 * 28 + i2_1_1 * 28 + i2_2 * 7 + i2_3)
co = T.axis.spatial(128, i3_0 * 64 + i3_1_1 * 64 + i3_2 * 16 + i3_3)
rh = T.axis.reduce(3, i4_0 * 3 + i4_1)
rw = T.axis.reduce(3, i5_0 + i5_1)
rc = T.axis.reduce(16, i6_0 * 2 + i6_1)
T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc], weight[rh, rw, rc, co])
T.writes(conv2d_nhwc_global[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
conv2d_nhwc_global[n, h, w, co] = T.float32(0)
conv2d_nhwc_global[n, h, w, co] = conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc] * weight[rh, rw, rc, co]
for ax0, ax1, ax2, ax3 in T.grid(1, 4, 28, 64):
with T.block("conv2d_nhwc_global"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(28, i1_0 * 4 + ax1)
v2 = T.axis.spatial(28, ax2)
v3 = T.axis.spatial(128, i3_0 * 64 + ax3)
T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
T.writes(conv2d_nhwc[v0, v1, v2, v3])
conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
@T.prim_func
def grp_2(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: T.Buffer[(3, 3, 16, 128), "float32"], conv2d_nhwc: T.Buffer[(1, 28, 28, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0 in T.grid(1, 7, 1, 2, 1, 4, 1, 1, 1, 3):
for ax0, ax1, ax2, ax3 in T.grid(1, 3, 55, 32):
with T.block("PadInput"):
i0 = T.axis.spatial(1, ax0)
i1 = T.axis.spatial(58, i1_0 * 8 + i1_1 * 2 + ax1)
i2 = T.axis.spatial(58, i5_0 + ax2)
i3 = T.axis.spatial(64, i3_0 * 32 + ax3)
T.reads(inputs[i0, i1 - 1, i2 - 1, i3])
T.writes(PadInput[i0, i1, i2, i3])
PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, inputs[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
for i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16):
with T.block("conv2d_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
h = T.axis.spatial(28, i1_0 * 4 + i1_1 + i1_2 + i1_3)
w = T.axis.spatial(28, i2_0 * 28 + i2_1 * 28 + i2_2 * 7 + i2_3)
co = T.axis.spatial(128, i3_0 * 64 + i3_1 * 64 + i3_2 * 16 + i3_3)
rh = T.axis.reduce(3, i4_0 * 3 + i4_1)
rw = T.axis.reduce(3, i5_0 + i5_1)
rc = T.axis.reduce(16, i6_0 * 2 + i6_1)
T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc], weight[rh, rw, rc, co])
T.writes(conv2d_nhwc[n, h, w, co])
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
conv2d_nhwc[n, h, w, co] = T.float32(0)
conv2d_nhwc[n, h, w, co] = conv2d_nhwc[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc] * weight[rh, rw, rc, co]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [7, 4, 1, 1]),
("SamplePerfectTile", [1, 1, 4, 7]),
("SamplePerfectTile", [2, 1, 4, 16]),
("SamplePerfectTile", [1, 3]),
("SamplePerfectTile", [3, 1]),
("SamplePerfectTile", [8, 2]),
("SampleCategorical", 1),
("SampleComputeLocation", 3),
]
decision_1 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [7, 4, 1, 1]),
("SamplePerfectTile", [1, 1, 4, 7]),
("SamplePerfectTile", [2, 1, 4, 16]),
("SamplePerfectTile", [1, 3]),
("SamplePerfectTile", [3, 1]),
("SamplePerfectTile", [8, 2]),
("SampleCategorical", 3),
("SampleComputeLocation", -1),
]
decision_2 = [
("SamplePerfectTile", [1, 1, 1, 1]),
("SamplePerfectTile", [7, 4, 1, 1]),
("SamplePerfectTile", [1, 1, 4, 7]),
("SamplePerfectTile", [2, 1, 4, 16]),
("SamplePerfectTile", [1, 3]),
("SamplePerfectTile", [3, 1]),
("SamplePerfectTile", [8, 2]),
("SampleCategorical", 1),
("SampleComputeLocation", 9),
]
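# How to read the decision lists above (explanatory note): the seven
# SamplePerfectTile entries are the split factors for the n, h, w, co spatial
# axes and the rh, rw, rc reduction axes of the SSRSRS tiling, e.g.
# [7, 4, 1, 1] splits h into 28 = 7 * 4 * 1 * 1; SampleCategorical picks the
# meta_schedule.unroll_explicit option; SampleComputeLocation picks the loop
# under which PadInput is computed (-1 leaves PadInput at the root scope, as
# in grp_1 above).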
mod = create_te_workload("GRP", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[grp_0, grp_1, grp_2],
expected_decisions=[decision_0, decision_1, decision_2],
)
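For readers who do not know the workload by name, GRP here is a grouped conv2d in NHWC layout with stride 2, padding 1, and 4 groups, which is what the sketches above compute. The NumPy reference below is an illustrative sketch inferred from the buffer shapes and indexing in the TIR (it is not the te_workload definition); the helper name and parameters are assumptions for this note only.

# NumPy reference for the GRP workload: grouped conv2d, NHWC layout, stride 2,
# padding 1, 4 groups.  inputs (1, 56, 56, 64) x weight (3, 3, 16, 128)
# -> conv2d_nhwc (1, 28, 28, 128).  Sketch inferred from the TIR above.
import numpy as np

def grouped_conv2d_nhwc_ref(inputs, weight, stride=2, padding=1, groups=4):
    n, h, w, ci = inputs.shape           # (1, 56, 56, 64)
    kh, kw, ci_g, co = weight.shape      # (3, 3, 16, 128); ci_g == ci // groups
    co_g = co // groups                  # 32 output channels per group
    padded = np.zeros((n, h + 2 * padding, w + 2 * padding, ci), inputs.dtype)
    padded[:, padding:padding + h, padding:padding + w, :] = inputs
    oh = (h + 2 * padding - kh) // stride + 1
    ow = (w + 2 * padding - kw) // stride + 1
    out = np.zeros((n, oh, ow, co), inputs.dtype)
    for oc in range(co):
        g = oc // co_g                   # group index, i.e. co // 32 in the TIR
        for rh in range(kh):
            for rw in range(kw):
                for rc in range(ci_g):
                    ic = g * ci_g + rc   # co // 32 * 16 + rc in the TIR
                    out[:, :, :, oc] += (
                        padded[:, rh:rh + oh * stride:stride,
                               rw:rw + ow * stride:stride, ic]
                        * weight[rh, rw, rc, oc]
                    )
    return out

Calling grouped_conv2d_nhwc_ref on float32 arrays of shapes (1, 56, 56, 64) and (3, 3, 16, 128) yields the (1, 28, 28, 128) tensor that the schedules above produce.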


if __name__ == "__main__":
test_cpu_c1d()
test_cpu_c2d()
@@ -1209,3 +1383,4 @@ def gmm_2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "flo
test_cpu_dep()
test_cpu_dil()
test_cpu_gmm()
test_cpu_grp()
90 changes: 90 additions & 0 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -653,6 +653,95 @@ def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "flo
)


def test_cuda_grp():
# fmt: off
@T.prim_func
def grp_0(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: T.Buffer[(3, 3, 16, 128), "float32"], conv2d_nhwc: T.Buffer[(1, 28, 28, 128), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# body
with T.block("root"):
T.reads()
T.writes()
T.block_attr({"meta_schedule.unroll_explicit":16})
conv2d_nhwc_local = T.alloc_buffer([1, 28, 28, 128], dtype="float32", scope="local")
PadInput_shared = T.alloc_buffer([1, 58, 58, 64], dtype="float32", scope="shared")
weight_shared = T.alloc_buffer([3, 3, 16, 128], dtype="float32", scope="shared")
for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(2, thread="blockIdx.x"):
for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(1, thread="vthread.x"):
for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(112, thread="threadIdx.x"):
for i4_0, i5_0, i6_0 in T.grid(3, 3, 1):
for ax0_ax1_ax2_ax3_fused in T.serial(95040):
with T.block("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused * 28 + i4_0 + ax0_ax1_ax2_ax3_fused % 95040 // 3520)
v2 = T.axis.spatial(58, i5_0 + ax0_ax1_ax2_ax3_fused % 3520 // 64)
v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
T.reads(inputs[v0, v1 - 1, v2 - 1, v3])
T.writes(PadInput_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":2})
PadInput_shared[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, inputs[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32")
for ax0_ax1_ax2_ax3_fused in T.serial(2048):
with T.block("weight_shared"):
v0, v1 = T.axis.remap("SS", [i4_0, i5_0])
v2 = T.axis.spatial(16, ax0_ax1_ax2_ax3_fused // 128)
v3 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused % 128)
T.reads(weight[v0, v1, v2, v3])
T.writes(weight_shared[v0, v1, v2, v3])
T.block_attr({"meta_schedule.cooperative_fetch":1})
weight_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3]
for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 2, 1, 2, 1, 2, 1, 1, 8, 1, 7, 4, 4):
with T.block("conv2d_nhwc"):
n = T.axis.spatial(1, i0_3 + i0_4)
h = T.axis.spatial(28, i0_0_i1_0_i2_0_i3_0_fused * 14 + i1_3 * 7 + i1_4)
w = T.axis.spatial(28, i0_2_i1_2_i2_2_i3_2_fused // 16 * 4 + i2_3 * 4 + i2_4)
co = T.axis.spatial(128, i0_2_i1_2_i2_2_i3_2_fused % 16 * 8 + i3_3 * 4 + i3_4)
rh = T.axis.reduce(3, i4_0 + i4_1 + i4_2)
rw = T.axis.reduce(3, i5_2 + i5_0 + i5_1)
rc = T.axis.reduce(16, i6_0 * 16 + i6_1 * 8 + i6_2)
T.reads(PadInput_shared[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc], weight_shared[rh, rw, rc, co])
T.writes(conv2d_nhwc_local[n, h, w, co])
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
with T.init():
conv2d_nhwc_local[n, h, w, co] = T.float32(0)
conv2d_nhwc_local[n, h, w, co] = conv2d_nhwc_local[n, h, w, co] + PadInput_shared[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc] * weight_shared[rh, rw, rc, co]
for ax0, ax1, ax2, ax3 in T.grid(1, 14, 4, 8):
with T.block("conv2d_nhwc_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(28, i0_0_i1_0_i2_0_i3_0_fused * 14 + ax1)
v2 = T.axis.spatial(28, i0_2_i1_2_i2_2_i3_2_fused // 16 * 4 + ax2)
v3 = T.axis.spatial(128, i0_2_i1_2_i2_2_i3_2_fused % 16 * 8 + ax3)
T.reads(conv2d_nhwc_local[v0, v1, v2, v3])
T.writes(conv2d_nhwc[v0, v1, v2, v3])
conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_local[v0, v1, v2, v3]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 1, 1, 1, 1]),
("SamplePerfectTile", [2, 1, 1, 2, 7]),
("SamplePerfectTile", [1, 1, 7, 1, 4]),
("SamplePerfectTile", [1, 1, 16, 2, 4]),
("SamplePerfectTile", [3, 1, 1]),
("SamplePerfectTile", [3, 1, 1]),
("SamplePerfectTile", [1, 2, 8]),
("SampleCategorical", 1),
("SampleCategorical", 0),
("SampleCategorical", 1),
]
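# How to read decision_0 (explanatory note): the spatial axes are split five
# ways and the reduction axes three ways, matching the SSSRRSRS structure
# above; the three SampleCategorical entries cover the
# meta_schedule.cooperative_fetch factors of the two shared buffers and the
# unroll-explicit option (the exact ordering follows the trace).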
mod = create_te_workload("GRP", 0)
actual = ms.TuneContext(
mod=mod,
target=_target(),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules="default",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[grp_0],
expected_decisions=[decision_0],
)
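When the expected TVMScript in these tests needs to be regenerated or inspected by hand, the design space can be dumped with a few lines. The sketch below reuses the helpers already imported in this file (ms, create_te_workload, _target); printing via sch.mod.script() and sch.trace is an assumption about the Schedule API, not part of this PR.

# Sketch: print every sketch in the GRP design space for manual inspection.
def dump_grp_design_space():
    ctx = ms.TuneContext(
        mod=create_te_workload("GRP", 0),
        target=_target(),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules="default",
    )
    for i, sch in enumerate(ctx.generate_design_space()):
        print(f"===== sketch {i} =====")
        print(sch.mod.script())   # TVMScript of the scheduled module
        print(sch.trace)          # trace, including sampled decisions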


if __name__ == "__main__":
test_cuda_c1d()
test_cuda_c2d()
@@ -661,3 +750,4 @@ def gmm_0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "flo
test_cuda_dep()
test_cuda_dil()
test_cuda_gmm()
test_cuda_grp()