Skip to content

Commit

Permalink
cherry-pick flash attention exp_u20 (#2480)
Browse files Browse the repository at this point in the history
  • Loading branch information
Valentine233 authored Jan 16, 2024
1 parent eb7279b commit df2387e
Showing 1 changed file with 75 additions and 1 deletion.
76 changes: 75 additions & 1 deletion csrc/cpu/aten/kernels/FlashAttentionKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,80 @@ inline c10::SymFloat calculate_scale(
return c10::SymFloat(softmax_scale);
}

template <typename scalar_t>
inline Vectorized<scalar_t> exp_u20(Vectorized<scalar_t> data) {
return data.exp();
}
#if defined(CPU_CAPABILITY_AVX512)
// To implement exp_u20 here is faster than calling from add_softmax.h or PT
// vec512_float.h
inline Vectorized<float> exp_u20(Vectorized<float> data) {
__m512 values = __m512(data);
// A faster version of exp with ULP=20
static __m512 vec_factorial_1 =
_mm512_set1_ps(0.999999701f); // 1/factorial(1)
static __m512 vec_factorial_2 =
_mm512_set1_ps(0.499991506f); // 1/factorial(2)
static __m512 vec_factorial_3 =
_mm512_set1_ps(0.166676521f); // 1/factorial(3)
static __m512 vec_factorial_4 =
_mm512_set1_ps(0.0418978221f); // 1/factorial(4)
static __m512 vec_factorial_5 =
_mm512_set1_ps(0.00828929059f); // 1/factorial(5)
static __m512 vec_exp_log2ef =
(__m512)_mm512_set1_epi32(0x3fb8aa3b); // log2(e)
static __m512 vec_half = _mm512_set1_ps(0.5f);
static __m512 vec_one = _mm512_set1_ps(1.f);
static __m512 vec_zero = _mm512_set1_ps(0.f);
static __m512 vec_two = _mm512_set1_ps(2.f);
static __m512 vec_ln2f = (__m512)_mm512_set1_epi32(0x3f317218); // ln(2)
static __m512 vec_ln_flt_min = (__m512)_mm512_set1_epi32(0xc2aeac50);
static __m512 vec_ln_flt_max = (__m512)_mm512_set1_epi32(0x42b17218);
static __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
static int n_mantissa_bits = 23;

// exp(x) =
// = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem
// = 2^n * exp(r) // simplify the exp(n*ln(2)) expression

auto less_ln_flt_min_mask =
_mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/);
auto vec_src = _mm512_min_ps(values, vec_ln_flt_max);
vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min);

// fx = floorf(x * log2ef + 0.5)
auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half);
auto vec_fx_i = _mm512_cvt_roundps_epi32(
vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC);
vec_fx = _mm512_cvtepi32_ps(vec_fx_i);

// x = x - fx * ln2
auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src);

// compute polynomial
auto vec_res =
_mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4);
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3);
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2);
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1);
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one);

// compute 2^(n-1)
auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one);
auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number);
auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127);
vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
auto vec_two_pow_n = (__m512)vec_two_pow_n_i;
vec_two_pow_n =
_mm512_mask_blend_ps(less_ln_flt_min_mask, vec_two_pow_n, vec_zero);

// y = y * 2^n
vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n);
vec_res = _mm512_mul_ps(vec_res, vec_two);
return vec_res;
}
#endif

// 1) out = exp(a - val)
// 2) val = sum(out)
template <typename scalar_t>
Expand All @@ -140,7 +214,7 @@ inline void _exp_reduce_sum_fusion_kernel(
for (long i = 0; i < vec_size * (size / vec_size); i += vec_size) {
auto tmp0 = at::vec::Vectorized<scalar_t>::loadu(a + i);
auto tmp1 = tmp0 - vec_max;
auto tmp2 = tmp1.exp();
auto tmp2 = exp_u20(tmp1);
vec_tmp_sum += tmp2;
_store(out + i, tmp2);
}
Expand Down

0 comments on commit df2387e

Please sign in to comment.