
Commit fd5dcc5

Optimize GeGLU layer in Gemma (#2975)
1 parent 93dc5a2 commit fd5dcc5

6 files changed: +108 −77 lines

csrc/activation_kernels.cu

Lines changed: 48 additions & 25 deletions
@@ -2,52 +2,75 @@
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#include <cmath>
+
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
 namespace vllm {
 
-template<typename T>
-__device__ __forceinline__ T silu(const T& x) {
-  // x * sigmoid(x)
-  return (T) (((float) x) / (1.0f + expf((float) -x)));
-}
-
-template<typename scalar_t>
-__global__ void silu_and_mul_kernel(
+// Activation and gating kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void act_and_mul_kernel(
   scalar_t* __restrict__ out,               // [..., d]
   const scalar_t* __restrict__ input,       // [..., 2, d]
   const int d) {
   const int64_t token_idx = blockIdx.x;
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = silu(x) * y;
+    out[token_idx * d + idx] = ACT_FN(x) * y;
   }
 }
 
+template<typename T>
+__device__ __forceinline__ T silu_kernel(const T& x) {
+  // x * sigmoid(x)
+  return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'none' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
+  const float f = (float) x;
+  constexpr float ALPHA = M_SQRT1_2;
+  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+}
+
 } // namespace vllm
 
+// Launch activation and gating kernel.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                            \
+  int d = input.size(-1) / 2;                                                            \
+  int64_t num_tokens = input.numel() / input.size(-1);                                   \
+  dim3 grid(num_tokens);                                                                 \
+  dim3 block(std::min(d, 1024));                                                         \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                      \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                          \
+  VLLM_DISPATCH_FLOATING_TYPES(                                                          \
+    input.scalar_type(),                                                                 \
+    "act_and_mul_kernel",                                                                \
+    [&] {                                                                                \
+      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(  \
+        out.data_ptr<scalar_t>(),                                                        \
+        input.data_ptr<scalar_t>(),                                                      \
+        d);                                                                              \
+    });
+
 void silu_and_mul(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input)    // [..., 2 * d]
 {
-  int64_t num_tokens = input.numel() / input.size(-1);
-  int d = input.size(-1) / 2;
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(d, 1024));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(),
-    "silu_and_mul_kernel",
-    [&] {
-      vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(),
-        input.data_ptr<scalar_t>(),
-        d);
-    });
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+}
+
+void gelu_and_mul(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input)    // [..., 2 * d]
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
 }
 
 namespace vllm {
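For reference, the fused kernel's math is easy to state in PyTorch: split the input's last dimension in half, apply the activation to the first half, and multiply elementwise by the second half. A minimal sketch (the helper names below are illustrative, not part of the diff); the erf-based formula mirrors gelu_kernel and PyTorch's 'none' GELU approximation:

import math
import torch

def ref_act_and_mul(x: torch.Tensor, act) -> torch.Tensor:
    # x: [..., 2 * d] -> out: [..., d], as computed by act_and_mul_kernel.
    d = x.shape[-1] // 2
    return act(x[..., :d]) * x[..., d:]

def ref_silu(t: torch.Tensor) -> torch.Tensor:
    return t * torch.sigmoid(t)  # matches vllm::silu_kernel

def ref_gelu(t: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + erf(x / sqrt(2))), matches vllm::gelu_kernel
    return 0.5 * t * (1.0 + torch.erf(t * math.sqrt(0.5)))

x = torch.randn(4, 2 * 8)
out_silu = ref_act_and_mul(x, ref_silu)  # shape [4, 8]
out_gelu = ref_act_and_mul(x, ref_gelu)  # shape [4, 8]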

csrc/ops.h

Lines changed: 4 additions & 0 deletions
@@ -57,6 +57,10 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
 
+void gelu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 void gelu_new(
   torch::Tensor& out,
   torch::Tensor& input);

csrc/pybind.cpp

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  ops.def(
+    "gelu_and_mul",
+    &gelu_and_mul,
+    "Activation function used in GeGLU.");
   ops.def(
     "gelu_new",
     &gelu_new,
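The binding makes the fused op callable from Python with a preallocated output, mirroring the launcher's contract (out: [..., d], input: [..., 2 * d]). A hedged sketch of a direct call; the import path below is an assumption about how the compiled extension is exposed, so adjust it to your build:

import torch
from vllm._C import ops  # assumed extension module path

num_tokens, d = 16, 128
x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
ops.gelu_and_mul(out, x)  # writes GELU(x[:, :d]) * x[:, d:] into out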

tests/kernels/test_activation.py

Lines changed: 15 additions & 35 deletions
@@ -1,7 +1,10 @@
+from typing import Type
+
 import pytest
 import torch
 
-from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
+from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
+                                                   NewGELU, SiluAndMul)
 from allclose_default import get_default_atol, get_default_rtol
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -13,13 +16,15 @@
 ]
 
 
+@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_silu_and_mul(
+def test_act_and_mul(
+    activation: Type[torch.nn.Module],
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -31,48 +36,23 @@ def test_silu_and_mul(
         torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    layer = SiluAndMul()
+    layer = activation()
     out = layer(x)
     ref_out = layer._forward(x)
-    assert torch.allclose(out,
-                          ref_out,
-                          atol=get_default_atol(out),
-                          rtol=get_default_rtol(out))
+    # The SiLU and GELU implementations are equivalent to the native PyTorch
+    # implementations, so we can do exact comparison.
+    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
 
 
+@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gelu_new(
-    num_tokens: int,
-    d: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-    torch.set_default_device(device)
-    x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = NewGELU()
-    out = layer(x)
-    ref_out = layer._forward(x)
-    assert torch.allclose(out,
-                          ref_out,
-                          atol=get_default_atol(out),
-                          rtol=get_default_rtol(out))
-
-
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
-@pytest.mark.parametrize("d", D)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_gelu_fast(
+def test_activation(
+    activation: Type[torch.nn.Module],
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -84,7 +64,7 @@ def test_gelu_fast(
         torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = FastGELU()
+    layer = activation()
    out = layer(x)
    ref_out = layer._forward(x)
    assert torch.allclose(out,
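Outside of pytest, the new exact-equality check can be reproduced for a single configuration with a short standalone script (a hedged sketch assuming a CUDA device and the compiled vLLM kernels are available):

import torch
from vllm.model_executor.layers.activation import GeluAndMul

torch.manual_seed(0)
x = torch.randn(7, 2 * 512, dtype=torch.half, device="cuda")
layer = GeluAndMul()
out = layer(x)               # fused CUDA kernel
ref_out = layer._forward(x)  # PyTorch-native reference
# Both paths use the erf-based GELU, so the comparison can be exact.
assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)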

vllm/model_executor/layers/activation.py

Lines changed: 23 additions & 0 deletions
@@ -37,6 +37,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+class GeluAndMul(nn.Module):
+    """An activation function for GeGLU.
+
+    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
+        return: (batch_size, seq_len, d) or (num_tokens, d)
+    """
+
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d]) * x[..., d:]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.gelu_and_mul(out, x)
+        return out
+
+
 class NewGELU(nn.Module):
 
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
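In a model, GeluAndMul pairs naturally with a single projection that produces both the gate and up halves in one matmul. A hedged, single-GPU illustration using a plain nn.Linear in place of vLLM's parallel linear layers (GeGLUMLP and its attribute names are illustrative only; forward() needs CUDA and the compiled kernels, while _forward() is the PyTorch-native fallback):

import torch
import torch.nn as nn
from vllm.model_executor.layers.activation import GeluAndMul

class GeGLUMLP(nn.Module):
    # Illustrative module, not part of vLLM.
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # One matmul yields [gate | up] along the last dimension.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = GeluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(x)   # [..., 2 * intermediate_size]
        x = self.act_fn(gate_up)         # [..., intermediate_size]
        return self.down_proj(x)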

vllm/model_executor/models/gemma.py

Lines changed: 14 additions & 17 deletions
@@ -21,10 +21,11 @@
 from transformers import GemmaConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearMethodBase,
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -50,27 +51,21 @@ def __init__(
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(hidden_size,
-                                              intermediate_size,
-                                              bias=False,
-                                              linear_method=linear_method)
-        self.up_proj = ColumnParallelLinear(hidden_size,
-                                            intermediate_size,
-                                            bias=False,
-                                            linear_method=linear_method)
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            linear_method=linear_method)
-        self.act_fn = nn.GELU()
+        self.act_fn = GeluAndMul()
 
     def forward(self, x):
-        gate, _ = self.gate_proj(x)
-        gate = self.act_fn(gate)
-        up, _ = self.up_proj(x)
-        fuse = gate * up
-        outputs, _ = self.down_proj(fuse)
-        return outputs
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
 
 
 class GemmaAttention(nn.Module):
@@ -294,6 +289,8 @@ def load_weights(self,
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
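The two new stacked_params_mapping entries direct the weight loader to copy the checkpoint's gate_proj tensor into shard 0 and up_proj into shard 1 of the merged gate_up_proj weight, so the first half of the projection output stays the gate that GeluAndMul activates. A simplified, non-parallel sketch of that copy (the helper below is illustrative, not vLLM's loader):

import torch

def load_merged_gate_up(gate_up_weight: torch.Tensor,
                        gate_weight: torch.Tensor,
                        up_weight: torch.Tensor) -> None:
    # gate_up_weight: [2 * intermediate_size, hidden_size]
    # Shard 0 <- gate_proj, shard 1 <- up_proj, matching the mapping above.
    intermediate_size = gate_weight.shape[0]
    gate_up_weight[:intermediate_size].copy_(gate_weight)
    gate_up_weight[intermediate_size:].copy_(up_weight)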
