Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable gemv schedule for adreno #16932

Merged
merged 7 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 197 additions & 2 deletions python/tvm/dlight/gpu/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,17 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-
elif is_inner_reduction:
self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue)
return sch
elif target.kind.name == "opencl":
ret = self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue)
if ret is None:
return self.sch_outer_reduction_fallback(
sch, target, block, vector_input_buffers, epilogue
)
return sch
else:
return self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue)
return self.sch_outer_reduction_fallback(
sch, target, block, vector_input_buffers, epilogue
)

def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument
self,
Expand Down Expand Up @@ -486,7 +495,7 @@ def apply(
LOAD_V_SHARED = False
LOAD_V_VEC = -1
UNROLL = 8
TS, TR = 2, 32
TS, TR = 2, 64
elif target.kind.name == "vulkan":
VEC_C = 4
LOAD_V_SHARED = True
Expand Down Expand Up @@ -551,6 +560,192 @@ def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, un
block: tir.schedule.BlockRV,
vector_input_buffers: List[tir.Buffer],
epilogue_info: Optional[BlockInfo],
):
"""Schedule the inner reduction block."""
krishnaraj36 marked this conversation as resolved.
Show resolved Hide resolved

def get_max_factor(n, factors):
factors = sorted(factors, reverse=True)
for factor in factors:
if n % factor == 0:
return factor
return 1

def apply(
sch: tir.Schedule,
gemv,
TAG_S,
TAG_R,
TS,
TR,
SCALE_PACK,
DEC_PACK,
VEC_LOAD,
VEC_C,
LOAD_V_SHARED,
LOAD_V_VEC,
UNROLL,
LOAD_V_TILE,
):
# rfactor: reduce to tx * vec_c
batch, s, r, c = sch.get_loops(block=gemv)
s = sch.fuse(batch, s)
r = sch.fuse(r, c)
bx, ts = sch.split(s, factors=[None, TS], preserve_unit_iters=True)
r, v_tile, tr, tile_r, vec_c = sch.split(
r, factors=[None, LOAD_V_TILE, TR, SCALE_PACK, DEC_PACK], preserve_unit_iters=True
)
sch.reorder(bx, ts, r, v_tile, tile_r, tr, vec_c)
tr_vec_c = sch.fuse(tr, vec_c)
rf = sch.rfactor(tr_vec_c, 0)

# rfactor: reduce to tx
bx, ts, tr_vec_c = sch.get_loops(block=gemv)
tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True)
rf2 = sch.rfactor(tr, 0)

# bind, vectorize compute
bx, ts, r, v_tile, tile_r, tr_vec_c = sch.get_loops(block=rf)
tr, vec_c = sch.split(tr_vec_c, factors=[TR, DEC_PACK])
sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c)
# sch.bind(batch, "blockIdx.z")
sch.bind(bx, "blockIdx.x")
sch.bind(ts, "threadIdx.x")
sch.bind(tr, "threadIdx.y")
sch.vectorize(vec_c)

# decompose independent scale read to outer loop
block_rf_stmt = sch.get(rf)
if len(block_rf_stmt.reads) >= 3:
As_local = sch.cache_read(rf, read_buffer_index=2, storage_scope="local")
sch.compute_at(As_local, v_tile, preserve_unit_loops=True)
# *tile_thr, vec_s = sch.get_loops(block=As_local)
# sch.vectorize(vec_s)

Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local")
sch.compute_at(Aq_local, tile_r, preserve_unit_loops=True)
# *tile_thr, vec_s = sch.get_loops(block=Aq_local)
# sch.vectorize(vec_s)

if LOAD_V_SHARED:
V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared")
sch.compute_at(V_shared, r, preserve_unit_loops=True)
l = sch.get_loops(block=V_shared)[-1]
_, v_tile, tx, ty, vec = sch.split(
l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC], preserve_unit_iters=True
)
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec)

# reduce tile_s * tr * vec to tile_s * tr
sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
tr, vec_c, ts = sch.get_loops(block=rf2)[1:]
sch.reorder(ts, tr, vec_c)
sch.bind(ts, "threadIdx.x")
sch.bind(tr, "threadIdx.y")

# reduce tile_s * tr to tile_s
sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
tr, ts = sch.get_loops(block=gemv)[1:]
sch.reorder(ts, tr)
sch.bind(ts, "threadIdx.x")
sch.bind(tr, "threadIdx.y")

sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2])
sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])

sch.set_scope(rf, buffer_index=0, storage_scope="local")
sch.set_scope(rf2, buffer_index=0, storage_scope="local")

sch.annotate(
block_or_loop=sch.get_loops(rf2)[3],
ann_key="pragma_auto_unroll_max_step",
ann_val=DEC_PACK,
)
sch.annotate(
krishnaraj36 marked this conversation as resolved.
Show resolved Hide resolved
block_or_loop=sch.get_loops(rf2)[3], ann_key="pragma_unroll_explicit", ann_val=1
)

# Schedule epilogue
if epilogue_info is not None:
epilogue = epilogue_info.block_rv
if is_broadcast_epilogue(sch, block, epilogue):
sch.reverse_compute_at(epilogue, bx)
sch.set_scope(block, 0, "shared")
_, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name
_, tx = sch.split(sch.fuse(*s), factors=[None, TX])
sch.bind(tx, "threadIdx.x")
else:
sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:])
ts_tile_s = sch.get_loops(epilogue)[-1]
ts, _ = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
sch.bind(ts, "threadIdx.x")
sch.set_scope(block, 0, "local")
return sch
# return sch.mod["main"].with_attr("tir.is_scheduled", 1)
krishnaraj36 marked this conversation as resolved.
Show resolved Hide resolved

# Specify the `len_tx` and `len_ty` according to the loop extent
batch, s, r, c = sch.get_loops(block=block)
_, len_s, len_r, len_c = (
get_extent(sch, batch),
get_extent(sch, s),
get_extent(sch, r),
get_extent(sch, c),
)

TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
VEC_C = 1
UNROLL = 4
TS, TR = 64, 4
DEC_PACK = 8
SCALE_PACK = 4
LOAD_V_SHARED = False
LOAD_V_VEC = 4
LOAD_V_TILE = 8

if LOAD_V_SHARED is False:
LOAD_V_TILE = 1

if not isinstance(len_r, int):
return None

if isinstance(len_s, int) and len_s > 32000:
return None

_, TILE_R = (
1,
len_c
if len_c > 1
else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1),
)
LOAD_V_VEC = min(get_max_factor(TILE_R, [1, 2, 4, 8]), LOAD_V_VEC)
VEC_LOAD = 1

return apply(
sch,
gemv=block,
TAG_S=TAG_S,
TAG_R=TAG_R,
TS=TS,
TR=TR,
SCALE_PACK=SCALE_PACK,
DEC_PACK=DEC_PACK,
VEC_LOAD=VEC_LOAD,
VEC_C=VEC_C,
LOAD_V_SHARED=LOAD_V_SHARED,
LOAD_V_VEC=LOAD_V_VEC,
UNROLL=UNROLL,
LOAD_V_TILE=LOAD_V_TILE,
)

def sch_outer_reduction_fallback( # pylint: disable=too-many-arguments, invalid-name, unused-argument
self,
sch: tir.Schedule,
target: Target,
block: tir.schedule.BlockRV,
vector_input_buffers: List[tir.Buffer],
epilogue_info: Optional[BlockInfo],
):
"""Schedule the outer reduction block."""
# NOTE: Only Android is supported so far
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/dlight/gpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def get_configs(self, target: Target) -> Config:
elif target.kind.name == "opencl" and "android" in str(target.host):
return Matmul.Config(
block_size_x=8,
block_size_y=8,
block_size_y=16,
vthread_x=1,
vthread_y=1,
micro_size_x=8,
Expand Down
Loading
Loading