Cache weight for large batch inference for full bf16 and WOQ lowp-mode=bf16 (#2898)

* Keep bf16 weight for WOQ first token

* Keep first token weight for woq int4 and full bf16

* Revert unnecessary changes

* fix clang-format issue

* Fix UT failures

* Fix concat linear

* Fix UT failures

* fix lint issue

* Cache extra weight at runtime instead of ahead-of-time

* fix lint
Xia-Weiwen authored Jun 18, 2024
1 parent 2795053 commit 52f8c48
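
The change keeps an extra copy of the weight for the large-batch first-token (prefill) pass, both for full bf16 and for weight-only quantization (WOQ) with lowp-mode=bf16, and, per the final commit above, materializes that copy lazily at runtime instead of ahead of time, so decode-only workloads never pay the extra memory. A minimal sketch of the lazy-caching pattern (hypothetical names and unpack routine, not the actual IPEX code):

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Hypothetical illustration of runtime weight caching for large-batch
// (first-token) inference; names and the unpack routine are assumptions.
struct LinearContext {
  at::Tensor packed_weight;              // quantized / packed weight, always present
  c10::optional<at::Tensor> bf16_weight; // plain bf16 copy, cached at runtime

  const at::Tensor& get_bf16_weight() {
    if (!bf16_weight.has_value()) {
      // Materialized on first use instead of ahead-of-time, so runs that
      // never hit the large-batch path allocate nothing extra.
      bf16_weight = unpack_to_bf16(packed_weight);
    }
    return *bf16_weight;
  }

  static at::Tensor unpack_to_bf16(const at::Tensor& packed) {
    // Stand-in for the real dequantize/unpack routine.
    return packed.to(at::kBFloat16);
  }
};

A forward path would then branch on batch size: large (prefill) batches use the cached bf16 weight, while small decode batches keep the regular packed path.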
Showing 30 changed files with 968 additions and 283 deletions.
30 changes: 29 additions & 1 deletion csrc/cpu/aten/TPPGEMM.cpp
@@ -15,6 +15,17 @@ IPEX_DEFINE_DISPATCH(tpp_linear_relu_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_add_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_mul_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_add_add_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_gelu_tanh_bf16_kernel_stub);

void tpp_gelu_tanh_bf16_forward_cpu(
at::BFloat16* in,
at::BFloat16* out,
int M,
int N,
int ldi,
int ldo) {
tpp_gelu_tanh_bf16_kernel_stub(kCPU, in, out, M, N, ldi, ldo);
}

at::Tensor tpp_linear_nobias_forward_cpu(
const at::Tensor& t_in,
@@ -36,7 +47,15 @@ at::Tensor tpp_linear_gelu_forward_cpu(
const at::Tensor& t_wt,
const at::Tensor& t_bias,
c10::optional<int64_t> out_features) {
return tpp_linear_gelu_kernel_stub(kCPU, t_in, t_wt, t_bias);
return tpp_linear_gelu_kernel_stub(kCPU, t_in, t_wt, t_bias, "none");
}

at::Tensor tpp_linear_gelu_tanh_forward_cpu(
const at::Tensor& t_in,
const at::Tensor& t_wt,
const at::Tensor& t_bias,
c10::optional<int64_t> out_features) {
return tpp_linear_gelu_kernel_stub(kCPU, t_in, t_wt, t_bias, "tanh");
}

at::Tensor tpp_fused_gate_up_proj_forward_cpu(
@@ -129,6 +148,15 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
torch_ipex::cpu::tpp_linear_gelu_forward_cpu);
}

TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
m.def(
"tpp_linear_gelu_tanh(Tensor t_in, Tensor t_wt, Tensor t_bias, int? out_features=None)-> Tensor out");
m.impl(
"tpp_linear_gelu_tanh",
c10::DispatchKey::CPU,
torch_ipex::cpu::tpp_linear_gelu_tanh_forward_cpu);
}

TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
m.def(
"tpp_fused_gate_up_proj(Tensor t_in, Tensor t_wt_gate, Tensor t_bias_gate, Tensor t_wt_up, Tensor t_bias_up,int? out_features=None)-> Tensor out");
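
With the fragment above registered, the new tpp_linear_gelu_tanh op is reachable through the PyTorch dispatcher. A minimal C++ caller, as a sketch: it assumes t_wt is already packed in the blocked TPP layout the kernel expects (output feature count wt_sizes[0] * wt_sizes[3], as used in the kernel file below):

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical caller; assumes t_wt is already TPP-block-packed.
at::Tensor call_tpp_linear_gelu_tanh(
    const at::Tensor& t_in,
    const at::Tensor& t_wt,
    const at::Tensor& t_bias) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torch_ipex::tpp_linear_gelu_tanh", "")
          .typed<at::Tensor(
              const at::Tensor&,
              const at::Tensor&,
              const at::Tensor&,
              c10::optional<int64_t>)>();
  return op.call(t_in, t_wt, t_bias, c10::nullopt);
}

From Python the same op should be reachable as torch.ops.torch_ipex.tpp_linear_gelu_tanh(t_in, t_wt, t_bias).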
27 changes: 25 additions & 2 deletions csrc/cpu/aten/TPPGEMM.h
@@ -23,6 +23,12 @@ at::Tensor tpp_linear_gelu_forward_cpu(
const at::Tensor& t_bias,
c10::optional<int64_t> out_features);

at::Tensor tpp_linear_gelu_tanh_forward_cpu(
const at::Tensor& t_in,
const at::Tensor& t_wt,
const at::Tensor& t_bias,
c10::optional<int64_t> out_features);

at::Tensor tpp_fused_gate_up_proj_forward_cpu(
const at::Tensor& t_in,
const at::Tensor& t_wt_gate,
@@ -67,14 +73,25 @@ at::Tensor tpp_linear_add_add_forward_cpu(
double scale,
c10::optional<int64_t> out_features);

void tpp_gelu_tanh_bf16_forward_cpu(
at::BFloat16* in,
at::BFloat16* out,
int M,
int N,
int ldi,
int ldo);

using tpp_linear_nobias_impl_fn =
at::Tensor (*)(const at::Tensor&, const at::Tensor&);

using tpp_linear_bias_kernel_impl_fn =
at::Tensor (*)(const at::Tensor&, const at::Tensor&, const at::Tensor&);

using tpp_linear_gelu_kernel_impl_fn =
at::Tensor (*)(const at::Tensor&, const at::Tensor&, const at::Tensor&);
using tpp_linear_gelu_kernel_impl_fn = at::Tensor (*)(
const at::Tensor&,
const at::Tensor&,
const at::Tensor&,
const c10::string_view&);

using tpp_fused_gate_up_proj_kernel_impl_fn = at::Tensor (*)(
const at::Tensor&,
@@ -110,6 +127,9 @@ using tpp_linear_add_add_kernel_impl_fn = at::Tensor (*)(
const at::Tensor&,
double);

using tpp_gelu_tanh_bf16_kernel_impl_fn =
void (*)(at::BFloat16*, at::BFloat16*, int, int, int, int);

IPEX_DECLARE_DISPATCH(tpp_linear_nobias_impl_fn, tpp_linear_nobias_kernel_stub);
IPEX_DECLARE_DISPATCH(
tpp_linear_bias_kernel_impl_fn,
@@ -135,6 +155,9 @@ IPEX_DECLARE_DISPATCH(
IPEX_DECLARE_DISPATCH(
tpp_linear_add_add_kernel_impl_fn,
tpp_linear_add_add_kernel_stub);
IPEX_DECLARE_DISPATCH(
tpp_gelu_tanh_bf16_kernel_impl_fn,
tpp_gelu_tanh_bf16_kernel_stub);

} // namespace cpu
} // namespace torch_ipex
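
The declarations above follow the DispatchStub pattern: the header declares a function-pointer type and a stub, IPEX_DEFINE_DISPATCH instantiates the stub, and each kernel translation unit registers its implementation with IPEX_REGISTER_DISPATCH so the right build is picked at runtime. A stripped-down, self-contained model of the idea (illustrative only; the real macros also handle per-ISA selection among AVX2/AVX512 builds):

#include <cmath>
#include <cstdio>

// Simplified model of IPEX_DECLARE/DEFINE/REGISTER_DISPATCH.
using gelu_tanh_fn = void (*)(const float*, float*, int);

// "Declared" stub: a slot filled in by whichever kernel registers itself.
static gelu_tanh_fn gelu_tanh_stub = nullptr;

// One concrete kernel (scalar here; real builds register vectorized variants).
static void gelu_tanh_kernel(const float* in, float* out, int n) {
  for (int i = 0; i < n; ++i) {
    float x = in[i];
    out[i] = 0.5f * x *
        (1.0f + std::tanh(0.79788456f * (x + 0.044715f * x * x * x)));
  }
}

// "Registration" runs at static-init time, like IPEX_REGISTER_DISPATCH.
static const bool registered = [] {
  gelu_tanh_stub = &gelu_tanh_kernel;
  return true;
}();

int main() {
  float in[3] = {-1.0f, 0.0f, 1.0f}, out[3];
  gelu_tanh_stub(in, out, 3); // call through the stub, as the ops above do
  std::printf("%f %f %f\n", out[0], out[1], out[2]);
}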
72 changes: 69 additions & 3 deletions csrc/cpu/aten/kernels/TPPGEMMKrnl.cpp
@@ -65,7 +65,13 @@ at::Tensor tpp_linear_nobias_kernel_impl(
at::Tensor tpp_linear_gelu_kernel_impl(
const at::Tensor& t_in,
const at::Tensor& t_wt,
const at::Tensor& t_bias) {
const at::Tensor& t_bias,
const c10::string_view& algorithm) {
AT_ASSERT(
algorithm == "none" || algorithm == "tanh",
"tpp_linear_gelu: Invalid gelu algorithm %s\n",
algorithm);

auto sizes = t_in.sizes().vec();
auto wt_sizes = t_wt.sizes();
sizes[2] = wt_sizes[0] * wt_sizes[3];
@@ -74,9 +80,18 @@

auto dt = t_wt.dtype();
if (dt == at::kFloat) {
torch_ipex::tpp::tpp_linear_gelu<float>(t_in, t_wt, t_bias, t_out);
if (algorithm == "none") {
torch_ipex::tpp::tpp_linear_gelu<float>(t_in, t_wt, t_bias, t_out);
} else { // tanh
torch_ipex::tpp::tpp_linear_gelu_tanh<float>(t_in, t_wt, t_bias, t_out);
}
} else if (dt == at::kBFloat16) {
torch_ipex::tpp::tpp_linear_gelu<at::BFloat16>(t_in, t_wt, t_bias, t_out);
if (algorithm == "none") {
torch_ipex::tpp::tpp_linear_gelu<at::BFloat16>(t_in, t_wt, t_bias, t_out);
} else { // tanh
torch_ipex::tpp::tpp_linear_gelu_tanh<at::BFloat16>(
t_in, t_wt, t_bias, t_out);
}
} else {
AT_ASSERT(
0,
@@ -240,6 +255,54 @@ at::Tensor tpp_linear_mul_kernel_impl(
return t_out;
}

void tpp_gelu_tanh_bf16_kernel_impl(
at::BFloat16* in,
at::BFloat16* out,
int M,
int N,
int ldi,
int ldo) {
#ifdef CPU_CAPABILITY_AVX512
const __m512 c1 = _mm512_set1_ps((float)0.7978846);
const __m512 c2 = _mm512_set1_ps((float)0.0356814);
const __m512 c_half = _mm512_set1_ps((float)0.5);
for (int j = 0; j < M; j++) {
int i;
for (i = 0; i < ALIGNDOWN(N, 16); i += 16) {
auto vin = torch_ipex::tpp::_mm512_loadu_ps_auto(&in[j * ldi + i]);
__m512 x_half = _mm512_mul_ps(vin, c_half);
__m512 x_sq = _mm512_mul_ps(vin, vin);
__m512 poly_x1 = _mm512_mul_ps(vin, _mm512_fmadd_ps(x_sq, c2, c1));
__m512 tanh_poly_x = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX3(poly_x1);
__m512 vout = _mm512_fmadd_ps(tanh_poly_x, x_half, x_half);
torch_ipex::tpp::_mm512_storeu_ps_auto(&out[j * ldo + i], vout);
}
if (i < N) {
int rem = N - i;
__mmask16 mask = (1 << rem) - 1;
auto vin =
torch_ipex::tpp::_mm512_maskz_loadu_ps_auto(mask, &in[j * ldi + i]);
__m512 x_half = _mm512_mul_ps(vin, c_half);
__m512 x_sq = _mm512_mul_ps(vin, vin);
__m512 poly_x1 = _mm512_mul_ps(vin, _mm512_fmadd_ps(x_sq, c2, c1));
__m512 tanh_poly_x = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX3(poly_x1);
__m512 vout = _mm512_fmadd_ps(tanh_poly_x, x_half, x_half);
torch_ipex::tpp::_mm512_mask_storeu_ps_auto(
&out[j * ldo + i], mask, vout);
}
}
#else
for (int j = 0; j < M; j++) {
for (int i = 0; i < N; i++) {
float x = in[j * ldi + i];
out[j * ldo + i] =
((tanh(sqrt(2 / M_PI) * (x + 0.044715 * std::pow(x, 3)))) + 1) * x *
0.5;
}
}
#endif
}

} // namespace

IPEX_REGISTER_DISPATCH(
@@ -265,6 +328,9 @@ IPEX_REGISTER_DISPATCH(tpp_linear_add_kernel_stub, &tpp_linear_add_kernel_impl);
IPEX_REGISTER_DISPATCH(
tpp_linear_add_add_kernel_stub,
&tpp_linear_add_add_kernel_impl);
IPEX_REGISTER_DISPATCH(
tpp_gelu_tanh_bf16_kernel_stub,
&tpp_gelu_tanh_bf16_kernel_impl);
} // namespace cpu
} // namespace torch_ipex
#endif
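
A note on the AVX512 path above: the two constants fold sqrt(2/pi) into the polynomial, c1 ≈ sqrt(2/pi) and c2 ≈ 0.044715 * sqrt(2/pi), so poly_x1 = x * (c1 + c2 * x^2) equals sqrt(2/pi) * (x + 0.044715 * x^3), and the final fmadd evaluates 0.5 * x * tanh(poly) + 0.5 * x, which is the same tanh-GELU the scalar fallback computes. A scalar sanity check of that factoring (illustrative, not part of the commit):

#include <cmath>
#include <cstdio>

// Checks that the factored polynomial used by the AVX512 kernel matches the
// textbook tanh-GELU. Constants mirror the kernel; the kernel's minimax tanh
// approximation is not modeled here, std::tanh is used for both sides.
int main() {
  const float c1 = 0.7978846f; // ~ sqrt(2/pi)
  const float c2 = 0.0356814f; // ~ 0.044715 * sqrt(2/pi)
  for (float x = -4.0f; x <= 4.0f; x += 0.5f) {
    float x_half = 0.5f * x;
    float poly = x * (c1 + c2 * x * x);               // kernel's fmadd form
    float kernel = std::tanh(poly) * x_half + x_half; // fmadd(tanh, x/2, x/2)
    float reference = 0.5f * x *
        (1.0f + std::tanh(std::sqrt(2.0f / (float)M_PI) *
                          (x + 0.044715f * x * x * x)));
    std::printf("x=% .2f kernel=% .6f ref=% .6f diff=% .2e\n",
                x, kernel, reference, kernel - reference);
  }
}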
[Diffs for the remaining 27 changed files are not shown.]
