1
1
"""Token blocks."""
2
- from typing import List
2
+ import weakref
3
+ from collections import defaultdict
4
+ from typing import Dict , List
3
5
4
6
from vllm .utils import Device
5
7
6
8
_BLANK_TOKEN_ID = - 1
7
9
8
10
DEFAULT_LAST_ACCESSED_TIME = - 1
9
11
12
+ TokensBlock = List [int ]
13
+
14
+
15
+ class BlockPool :
16
+ """A pool of physical blocks.
17
+ When requests come, we create a lot of logical blocks;
18
+ when requests are done, we destroy a lot of logical blocks.
19
+ It turns out that creating and destroying logical blocks can be expensive,
20
+ especially for the `token_ids` field, which is a list of integers.
21
+ To avoid this overhead, we use a pool to manage the logical blocks.
22
+ When an old request is done and a new request comes, we can reuse the
23
+ logical blocks from the old request to feed the new request.
24
+ """
25
+
26
+ def __init__ (self ) -> None :
27
+ # block size to list of token blocks
28
+ self .pool : Dict [int , List [TokensBlock ]] = defaultdict (list )
29
+
30
+ def alloc_block (self , block_size : int ) -> TokensBlock :
31
+ if block_size in self .pool and self .pool [block_size ]:
32
+ return self .pool [block_size ].pop ()
33
+ return [_BLANK_TOKEN_ID ] * block_size
34
+
35
+ def del_block (self , block : TokensBlock ) -> None :
36
+ self .pool [len (block )].append (block )
37
+
38
+
39
+ _BLOCK_POOL = BlockPool ()
40
+
10
41
11
42
class LogicalTokenBlock :
12
43
"""A block that stores a contiguous chunk of tokens from left to right.
@@ -23,7 +54,13 @@ def __init__(
23
54
self .block_number = block_number
24
55
self .block_size = block_size
25
56
26
- self .token_ids = [_BLANK_TOKEN_ID ] * block_size
57
+ self .token_ids = _BLOCK_POOL .alloc_block (block_size )
58
+ # this finalizer is used to return the block to the pool when the object is deleted # noqa
59
+ # NOTE: don't use __del__ because it cannot guarantee the order of finalization, # noqa
60
+ # i.e. `self.token_ids` may be deleted before `self`, and we lose
61
+ # the opportunity to return the block to the pool
62
+ self ._finalizer = weakref .finalize (self , _BLOCK_POOL .del_block ,
63
+ self .token_ids )
27
64
self .num_tokens = 0
28
65
29
66
def is_empty (self ) -> bool :
0 commit comments