Skip to content

Commit

Permalink
metal : simplify soft max kernel
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Dec 1, 2023
1 parent c4db592 commit d9c8fa3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ kernel void kernel_soft_max(
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

// parallel max
float lmax = (tpitg < ne00) ? (psrc0[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;
float lmax = -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}

Expand Down Expand Up @@ -284,9 +284,9 @@ kernel void kernel_soft_max_4(
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

// parallel max
float4 lmax4 = tpitg < ne00/4 ? (psrc4[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;
float4 lmax4 = -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}

Expand Down

0 comments on commit d9c8fa3

Please sign in to comment.