@@ -1,7 +1,10 @@
+from typing import Type
+
 import pytest
 import torch
 
-from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
+from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
+                                                   NewGELU, SiluAndMul)
 from allclose_default import get_default_atol, get_default_rtol
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -13,13 +16,15 @@
 ]
 
 
+@pytest.mark.parametrize("activation", [SiluAndMul, GeluAndMul])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_silu_and_mul(
+def test_act_and_mul(
+    activation: Type[torch.nn.Module],
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -31,48 +36,23 @@ def test_silu_and_mul(
         torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    layer = SiluAndMul()
+    layer = activation()
     out = layer(x)
     ref_out = layer._forward(x)
-    assert torch.allclose(out,
-                          ref_out,
-                          atol=get_default_atol(out),
-                          rtol=get_default_rtol(out))
+    # The SiLU and GELU implementations are equivalent to the native PyTorch
+    # implementations, so we can do exact comparison.
+    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
 
 
+@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gelu_new(
-    num_tokens: int,
-    d: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> None:
-    torch.random.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-    torch.set_default_device(device)
-    x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = NewGELU()
-    out = layer(x)
-    ref_out = layer._forward(x)
-    assert torch.allclose(out,
-                          ref_out,
-                          atol=get_default_atol(out),
-                          rtol=get_default_rtol(out))
-
-
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
-@pytest.mark.parametrize("d", D)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_gelu_fast(
+def test_activation(
+    activation: Type[torch.nn.Module],
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -84,7 +64,7 @@ def test_gelu_fast(
         torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = FastGELU()
+    layer = activation()
    out = layer(x)
     ref_out = layer._forward(x)
     assert torch.allclose(out,
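For context, the sketch below illustrates the pattern this diff introduces: parametrizing one test over several activation classes and comparing each module's forward() output against its _forward() reference path. The DummySiluAndMul and DummyGeluAndMul classes here are illustrative stand-ins rather than vLLM's real layers (those dispatch to custom CUDA kernels), so treat this as a sketch of the testing pattern, not of vLLM's implementation.

# Standalone sketch of the parametrize-over-activation-classes pattern.
# The Dummy* classes below are hypothetical stand-ins for illustration only.
from typing import Type

import pytest
import torch
import torch.nn.functional as F


class DummySiluAndMul(torch.nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reference path: SiLU on the first half, multiplied by the second half.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # In vLLM this would call a custom CUDA kernel; here it simply
        # reuses the reference path.
        return self._forward(x)


class DummyGeluAndMul(torch.nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward(x)


@pytest.mark.parametrize("activation", [DummySiluAndMul, DummyGeluAndMul])
@pytest.mark.parametrize("num_tokens", [7, 83])
@pytest.mark.parametrize("d", [512])
@torch.inference_mode()
def test_act_and_mul_sketch(
    activation: Type[torch.nn.Module],
    num_tokens: int,
    d: int,
) -> None:
    torch.manual_seed(0)
    x = torch.randn(num_tokens, 2 * d)
    layer = activation()
    out = layer(x)
    ref_out = layer._forward(x)
    # Exact comparison, mirroring the assertion introduced in the diff.
    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)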