Commit

larger x tiles
JohannesGaessler committed Jul 14, 2023
1 parent e395994 commit 60df883
Showing 2 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -172,7 +172,7 @@ ifdef LLAMA_CUBLAS
 	LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
 	OBJS += ggml-cuda.o
 	NVCC = nvcc
-	NVCCFLAGS = --forward-unknown-to-host-compiler
+	NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
 ifdef CUDA_DOCKER_ARCH
 	NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
 else
20 changes: 12 additions & 8 deletions ggml-cuda.cu
@@ -1662,24 +1662,24 @@ static __global__ void mul_mat_q(
     const int tid_x = threadIdx.x;
     const int tid_y = threadIdx.y;
 
-    const int row_dst_0 = blockIdx.x*WARP_SIZE;
+    const int row_dst_0 = 2*blockIdx.x*WARP_SIZE;
     const int & row_x_0 = row_dst_0;
     const int row_dst = row_dst_0 + tid_x;
 
     const int col_dst_0 = blockIdx.y*WARP_SIZE;
     const int & col_y_0 = col_dst_0;
 
-    __shared__ int tile_x_qs[WARP_SIZE][WARP_SIZE + 1];
-    __shared__ half tile_x_d[WARP_SIZE][WARP_SIZE/QI4_0];
+    __shared__ int tile_x_qs[2*WARP_SIZE][WARP_SIZE + 1];
+    __shared__ half tile_x_d[2*WARP_SIZE][WARP_SIZE/QI4_0];
     __shared__ int tile_y_qs[WARP_SIZE][2*WARP_SIZE];
     __shared__ half2 tile_y_ds[WARP_SIZE][2*WARP_SIZE/QI8_1];
-    float sum[4] = {0.0f};
+    float sum[2][4] = {0.0f};
 
     for (int ib0 = 0; ib0 < blocks_per_row; ib0 += blocks_per_warp) {
         const int ibx = tid_x / QI4_0;
         const int iqsx = sizeof(int) * (tid_x % QI4_0);
 
-        for (int j = 0; j < WARP_SIZE; j += 8) {
+        for (int j = 0; j < 2*WARP_SIZE; j += 8) {
             const block_q4_0 * __restrict__ bx = &x[(row_x_0 + j + tid_y)*blocks_per_row + ib0 + ibx];
             memcpy(&tile_x_qs[j + tid_y][tid_x], &bx->qs[iqsx], sizeof(int));
             tile_x_d[j + tid_y][ibx] = bx->d;
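
The hunk above doubles the x tile: each thread block now stages a 2*WARP_SIZE-row slice of the quantized x matrix in shared memory instead of a WARP_SIZE-row slice, so row_dst_0 advances by 2*WARP_SIZE per block and the per-thread accumulator becomes sum[2][4]. A minimal host-side sketch of the row coverage, assuming WARP_SIZE = 32 as in ggml-cuda.cu; the block index value is made up for illustration:

    // Sketch only (not from the commit): which dst rows one block covers.
    #include <cstdio>

    #define WARP_SIZE 32

    int main() {
        const int block_idx_x = 3; // hypothetical blockIdx.x

        const int old_row_dst_0 = block_idx_x * WARP_SIZE;     // before: rows 96..127
        const int new_row_dst_0 = 2*block_idx_x * WARP_SIZE;   // after:  rows 192..255

        printf("old tile: rows %d..%d\n", old_row_dst_0, old_row_dst_0 + WARP_SIZE - 1);
        printf("new tile: rows %d..%d\n", new_row_dst_0, new_row_dst_0 + 2*WARP_SIZE - 1);
        return 0;
    }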
@@ -1706,9 +1706,12 @@ static __global__ void mul_mat_q(
         for (int k = 0; k < WARP_SIZE; ++k) {
             const int iqsy = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
             for (int j = 0; j < WARP_SIZE; j += 8) {
-                sum[j/8] += vec_dot_q4_0_q8_1_impl(
+                sum[0][j/8] += vec_dot_q4_0_q8_1_impl(
                     tile_x_qs[tid_x][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
                     tile_x_d[tid_x][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
+                sum[1][j/8] += vec_dot_q4_0_q8_1_impl(
+                    tile_x_qs[tid_x + WARP_SIZE][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
+                    tile_x_d[tid_x + WARP_SIZE][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
             }
         }

@@ -1727,7 +1730,8 @@ static __global__ void mul_mat_q(
             return;
         }
 
-        dst[col_dst*nrows_dst + row_dst] = sum[j/8];
+        dst[col_dst*nrows_dst + row_dst] = sum[0][j/8];
+        dst[col_dst*nrows_dst + row_dst + WARP_SIZE] = sum[1][j/8];
     }
 }
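
With two accumulators per thread, the write-back stores two output elements per thread and column: one at row_dst and one at row_dst + WARP_SIZE, using the same dst indexing as the surrounding code. A stripped-down CUDA sketch of that pattern, not the actual kernel; the accumulator values and launch shape below are placeholders:

    // Sketch only: one thread writes two dst rows that are WARP_SIZE apart.
    #include <cuda_runtime.h>

    #define WARP_SIZE 32

    __global__ void write_two_rows(float * dst, const int nrows_dst) {
        const int row_dst = 2*blockIdx.x*WARP_SIZE + threadIdx.x;
        const int col_dst = blockIdx.y*WARP_SIZE + threadIdx.y;

        const float sum0 = 1.0f; // stands in for sum[0][j/8]
        const float sum1 = 2.0f; // stands in for sum[1][j/8]

        dst[col_dst*nrows_dst + row_dst]             = sum0;
        dst[col_dst*nrows_dst + row_dst + WARP_SIZE] = sum1;
    }

    int main() {
        const int nrows = 2*WARP_SIZE, ncols = WARP_SIZE;
        float * dst_d = nullptr;
        cudaMalloc(&dst_d, nrows*ncols*sizeof(float));
        // One block of WARP_SIZE x WARP_SIZE threads covers all rows and columns here.
        write_two_rows<<<dim3(1, 1, 1), dim3(WARP_SIZE, WARP_SIZE, 1)>>>(dst_d, nrows);
        cudaDeviceSynchronize();
        cudaFree(dst_d);
        return 0;
    }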

@@ -2417,7 +2421,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
 }
 
 static void ggml_mul_mat_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
-    const int block_num_x = (nrows_x + WARP_SIZE - 1) / WARP_SIZE;
+    const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
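
Because each block now covers 2*WARP_SIZE destination rows, the grid's x dimension is computed by rounding nrows_x up to a multiple of 2*WARP_SIZE, halving the number of blocks launched along x. A small sketch of that ceiling division, assuming WARP_SIZE = 32 and a made-up row count:

    // Sketch only: grid size along x before and after this commit.
    #include <cstdio>

    #define WARP_SIZE 32

    int main() {
        const int nrows_x = 4096; // hypothetical row count

        const int old_block_num_x = (nrows_x +   WARP_SIZE - 1) /    WARP_SIZE;  // 128
        const int new_block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); //  64

        printf("blocks along x: %d -> %d\n", old_block_num_x, new_block_num_x);
        return 0;
    }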
