diff --git a/ggml-metal.metal b/ggml-metal.metal index e152cc53c0b97..9a79f815f3a72 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -203,9 +203,9 @@ kernel void kernel_soft_max( device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max - float lmax = (tpitg < ne00) ? (psrc0[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY; + float lmax = -INFINITY; - for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) { + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)); } @@ -284,9 +284,9 @@ kernel void kernel_soft_max_4( device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max - float4 lmax4 = tpitg < ne00/4 ? (psrc4[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY; + float4 lmax4 = -INFINITY; - for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) { + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)); }