Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cuda : optimize argmax #10441

Merged
merged 6 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 53 additions & 42 deletions ggml/src/ggml-cuda/argmax.cu
Original file line number Diff line number Diff line change
@@ -1,57 +1,68 @@
#include "common.cuh"
#include <algorithm>
#include <cstdint>

#include "argmax.cuh"
#include "common.cuh"
#include "sum.cuh"

#include <cstdint>
static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t ncols) {
const int64_t row = blockIdx.x;

static __global__ void argmax_f32(
const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
float maxval = -FLT_MAX;
int argmax = -1;
const float * rowx = x + row * ncols;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the code again, I think either 64 bit should be used for the ne00 dimension or there should be an assert that 32 bit is enough.

Copy link
Member Author

@slaren slaren Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output is int32, so it would definitely not work with ne00 larger than INT_MAX. In that case it might make more sense to add the assert to ggml_argmax instead. Other arg* functions will have the same issue.


int argmax_thread = 0;
const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
const float val = rowx[col];
if (val > maxval) {
maxval = val;
argmax = col;
}
}

#pragma unroll
for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
const int64_t row = row0 + row1;

if (row >= nrows) {
break;
for (int offset = 16; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {
maxval = val;
argmax = col;
}
}
Comment on lines +27 to 30
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In retrospect it probably makes more sense to do it like this; conditional statements are problematic for code optimization since they prevent the compiler from reordering instructions but there isn't much to do in one loop iteration anyways.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't measure a meaningful difference in performance and this should be easier to understand and maintain. Maybe in some hardware it would make a difference? I also would expect the compiler to be able to optimize simple conditionals like this, but that may be expecting too much.


float maxval = -FLT_MAX;
int argmax = -1;

for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
const float val = x[row*ncols + col];
const int bigger = val > maxval;
const int not_bigger = bigger ^ 0x00000001;

maxval = maxval*not_bigger + val*bigger;
argmax = argmax*not_bigger + col*bigger;
const int n_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
if (n_warps > 1) {
constexpr int max_warps = 1024 / WARP_SIZE;
__shared__ float shared_maxval[max_warps];
__shared__ int shared_argmax[max_warps];
if (lane_id == 0) {
shared_maxval[warp_id] = maxval;
shared_argmax[warp_id] = argmax;
}

__syncthreads();

if (warp_id == 0 && lane_id < n_warps) {
maxval = shared_maxval[lane_id];
argmax = shared_argmax[lane_id];
const unsigned int mask = (1u << n_warps) - 1u;
#pragma unroll
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably faster to just have all threads participate in the shuffle unconditionally.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for doing this is that if there are less than 32 warps, then some values will not be written to the shared memory, so they should not be used.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestion would be to just have the first warp clear the memory and then do a __syncthreads before reading again.

for (int mask = 16; mask > 0; mask >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
const int bigger = val > maxval;
const int not_bigger = bigger ^ 0x00000001;

maxval = maxval*not_bigger + val*bigger;
argmax = argmax*not_bigger + col*bigger;
for (int offset = 16; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(mask, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(mask, argmax, offset, WARP_SIZE);
if (val > maxval) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CUDA documentation says:

Threads may only read data from another thread which is actively participating in the __shfl_sync() command. If the target thread is inactive, the retrieved value is undefined.

It doesn't explicitly mention __shfl_xor_sync but I suspect that the same hardware is used and that thus the same limitations apply.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the PTX documentation the behavior is definitely undefined.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what's the undefined behavior here, can you elaborate? The mask is set such as only the threads participating in the sync are used.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PTX documentations reads:

Note that results are undefined if a thread sources a register from an inactive thread or a thread that is not in membermask.

The problem is that even if you limit the participating threads via the mask they are still retrieving data from threads outside the mask. You would have to dynamically change the values of offset and in the most general case where n_warps is not a power of 2 you would need to use instructions other than __shfl_xor_sync.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thanks. It should be fixed now.

maxval = val;
argmax = col;
}
}
}

const int store = row1 == threadIdx.x;
argmax_thread += store*argmax;
}

const int row = row0 + threadIdx.x;

if (row >= nrows) {
return;
if (warp_id == 0 && lane_id == 0) {
dst[row] = argmax;
}

Comment on lines +64 to 66
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My experience is that conditional returns/continues are faster than conditional writes but it probably doesn't matter much.

dst[row] = argmax_thread;
}

void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
Expand All @@ -70,10 +81,10 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

cudaStream_t stream = ctx.stream();

const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;

const dim3 blocks_dim(WARP_SIZE, 1, 1);
const int64_t num_blocks = nrows;
const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
const dim3 blocks_dim(num_threads, 1, 1);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to be efficient for 32 <= ne00 <= 1024 and ne00 >> 1024 but inefficient for 1024 < ne00 <= 4096. And in general, if you have a variable block size you should make it a template parameter.

const dim3 blocks_num(num_blocks, 1, 1);

argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
}
30 changes: 15 additions & 15 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -180,26 +180,26 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
return __reduce_add_sync(0xffffffff, x);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
for (int offset = 16; offset > 0; offset >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
}
return x;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
}

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
for (int offset = 16; offset > 0; offset >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
}
return x;
}

static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
for (int offset = 16; offset > 0; offset >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
}
return a;
}
Expand All @@ -209,16 +209,16 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
for (int offset = 16; offset > 0; offset >>= 1) {
const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
reinterpret_cast<half&>(a.x) += __low2half(a_other);
reinterpret_cast<half&>(a.y) += __high2half(a_other);
}
return a;
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
for (int offset = 16; offset > 0; offset >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
}
return a;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
Expand All @@ -231,8 +231,8 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
for (int offset = 16; offset > 0; offset >>= 1) {
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
}
return x;
}
Expand Down Expand Up @@ -275,8 +275,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
for (int offset = 16; offset > 0; offset >>= 1) {
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
}
return x;
#else
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda/quantize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ static __global__ void quantize_mmq_q8_1(

// Exchange max. abs. value between vals_per_scale/4 threads.
#pragma unroll
for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
}

float sum;
Expand All @@ -79,8 +79,8 @@ static __global__ void quantize_mmq_q8_1(

// Exchange calculate sum across vals_per_sum/4 threads.
#pragma unroll
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
}
}

Expand Down
29 changes: 29 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,26 @@ struct test_argmax : public test_case {
return out;
}

void initialize_tensors(ggml_context * ctx) override {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_F32) {
// initialize with unique values to avoid ties
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<float> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = i;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
}
} else {
init_tensor_uniform(t);
}
}
}

double max_nmse_err() override {
return 0.0;
}
Expand Down Expand Up @@ -3441,6 +3461,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

test_cases.emplace_back(new test_argmax());
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
Comment on lines +3464 to +3467
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may want to also check the case with ne01 and ne00 flipped where whether or not the writes are coalesced makes a comparatively larger difference. But that would be the case with a very large batch size and few classes and especially with language models that have large vocabulary sizes I think it's not an important use case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean test for correctness or performance? These cases are the ones used in eval mode only.

I also tested the performance with [512,32000], and it drops to 480GB/s (compared to 730GB/s with [32000,512]). There are surely more optimization opportunities, but I don't think it is worth spending more time on this at moment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only meant performance. I wrote the code on master in the context of the ggml MNIST example with an input shape of {10, 1000, 1, 1}. In principle, if you have a low number of classes but a large number of datapoints the number of writes should become significant and it would make sense to try and coalesce them (but with the code on master there are likely also issues with tail effects because the number of CUDA blocks is reduced by a factor of 32). In the first place, I should have written code with a use case like 256000, 128, 1, 1 in mind since that is going to be relevant for llama.cpp.


test_cases.emplace_back(new test_count_equal());

for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
Expand Down Expand Up @@ -3831,6 +3856,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));

test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));

for (int bs : {1, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
Expand Down
Loading