Skip to content

Commit

Permalink
Browse files (browse the repository at this point in the history)
  • Loading branch information
casper-hansen committed Sep 21, 2024
1 parent 7954766 commit cec576b
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions awq/modules/triton/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
Expand Down Expand Up @@ -235,12 +235,9 @@ def awq_gemm_kernel(
c = accumulator.to(c_ptr.type.element_ty)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]
c_ptrs = c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
if SPLIT_K == 1:
tl.store(c_ptrs, c, mask=c_mask)
else:
tl.atomic_add(c_ptrs, c, mask=c_mask)
tl.store(c_ptrs, c, mask=c_mask)


# qweights - [K , M // 8], int32
Expand Down Expand Up @@ -328,7 +325,7 @@ def awq_gemm_triton(
split_k_iters,
)

result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)
result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device)

# A = input, B = qweight, C = result
# A = M x K, B = K x N, C = split_k_iters x M x N (one M x N partial per split-K slice, summed over dim 0 after the kernel)
Expand All @@ -348,4 +345,6 @@ def awq_gemm_triton(
SPLIT_K=split_k_iters,
)

result = result.sum(0)

return result

0 comments on commit cec576b

Please sign in to comment.