-import time
-from typing import List
import asyncio
+from typing import List
+
+from transformers import AutoTokenizer

from .dynamic_batching.infer_batch import InferBatch
from .dynamic_batching.io_struct import Batch, Req
from .dynamic_batching.stats import Stats
from .tensor_parallel import TPInferEngine

-from transformers import AutoTokenizer
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"

+
class DynamicBatchManager:
    def __init__(
        self,
        tp_engine: TPInferEngine,
        max_total_token_num,
        batch_max_tokens,
        eos_id,
+        model,
        log_stats=True,
        log_stats_interval=10,
        running_batch: Batch = None,
@@ -30,6 +32,7 @@ def __init__(
        batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
        running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
        eos_id : The end token of a seq
+        model: the model weight dir path, the app will load config, weights and tokenizer from this dir
        log_stats : whether to log stats
        log_stats_interval : log stats interval
        running_batch : running batch
@@ -45,32 +48,32 @@ def __init__(
        self.eos_id = eos_id
        self.has_wait_tokens = 0
        self.max_wait_tokens = 10
-
+        self.model = model
+
        self.stats_tool = Stats(log_stats, log_stats_interval)
        self.mem_usage_interval = log_stats_interval * 2
+        self._set_tokenizer(tokenizer_name=self.model)

-    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
+    async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
        """
        Add new request to req queue, during initialization all requests are held in waiting list.
        """
-        req = Req(request_id, prompt_ids, sampling_params)
+        req = Req(request_id, prompt_ids, sampling_params, prompts)
        self.req_queue.append(req)
        return

-    def add_input(self, request_id, sampling_params, input_ids):
+    async def add_input(self, request_id, sampling_params, prompts):
        """
        Encode and Add new input to req queue. support one sequence input for now.
        """
-        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_ids = self.tokenizer.encode(prompts)
        prompt_len = len(prompt_ids)
        if prompt_len > self.engine.max_input_len:
-            raise ValueError(
-                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
-            )
+            raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-        self.add_req(prompt_ids, sampling_params, request_id)
+        await self.add_req(request_id, prompt_ids, sampling_params, prompts)
        return
-
+
    def abort(self, request_id):
        if self.running_batch is not None:
            for req in self.running_batch.reqs:
@@ -88,10 +91,15 @@ async def loop_for_fwd(self):
        The main loop for a dynamic batching process.
        """
        counter_count = 0
-        #self.running_batch is not None or self.req_queue.waiting_req_list
+        # self.running_batch is not None or self.req_queue.waiting_req_list
        while True:
-            async for item in self._step():
-                yield item
+            if self.running_batch is not None or self.req_queue.waiting_req_list:
+                async for result in self._step():
+                    yield result
+            else:
+                # need to wait for new requests
+                await asyncio.sleep(0.1)
+                continue
            counter_count += 1
            if self.running_batch is not None:
                if counter_count % self.mem_usage_interval == 0:
@@ -103,30 +111,33 @@ async def loop_for_fwd(self):
                    )
                self.stats_tool.print_stats()

-            if self.running_batch is None:
-                time.sleep(0.1)  # 10ms
-
-    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,):
+    def _set_tokenizer(
+        self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True
+    ):
        if tokenizer is not None:
-            self.tokenizer = tokenizer
+            self.tokenizer = tokenizer
        else:
            if "llama" in tokenizer_name.lower() and use_fast == True:
                print(
-                    "For some LLaMA-based models, initializing the fast tokenizer may "
-                    "take a long time. To eliminate the initialization time, consider "
-                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-                    "tokenizer. This is done automatically in Colossalai.")
-
-                tokenizer_name = _FAST_LLAMA_TOKENIZER
-
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-            except TypeError as e:
+                    "For some LLaMA-based models, initializing the fast tokenizer may "
+                    "take a long time. To eliminate the initialization time, consider "
+                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+                    "tokenizer. This is done automatically in Colossalai."
+                )
+
+                tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+                )
+            except TypeError:
                use_fast = False
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+                )

-    def _step(self):
+    async def _step(self):
        """
        Logic for handling requests
        """
@@ -136,33 +147,36 @@ def _step(self):
            if new_batch is not None:
                self.stats_tool.count_prompt_tokens(new_batch)
                self.running_batch = new_batch
-                yield from self._prefill_batch(self.running_batch)
+                async for item in self._prefill_batch(self.running_batch):
+                    yield item
                self._filter_runing_batch()
                self.has_wait_tokens = 0
            return

        if self.has_wait_tokens < self.max_wait_tokens:
            self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            async for item in self._decode_batch(self.running_batch):
+                yield item
            self._filter_runing_batch()
            self.has_wait_tokens += 1
            return
        else:
            new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_mini_batch is not None:
                self.stats_tool.count_prompt_tokens(new_mini_batch)
-                yield from self._prefill_batch(new_mini_batch)
+                async for item in self._prefill_batch(new_mini_batch):
+                    yield item
                if not new_mini_batch.is_clear():
                    self._merge_batch(self.running_batch, new_mini_batch)
                    self.running_batch.merge(new_mini_batch)
                self.has_wait_tokens = 0
-
+
            else:
                self.stats_tool.count_output_tokens(self.running_batch)
-                yield from self._decode_batch(self.running_batch)
+                async for item in self._decode_batch(self.running_batch):
+                    yield item
                self._filter_runing_batch()
                self.has_wait_tokens += 1
-
+
        return

    def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -187,7 +201,7 @@ def _init_batch(self, batch: Batch, dtype="fp16"):
        )
        self.engine.cache[batch_id] = batch_data

-    def _prefill_batch(self, batch):
+    async def _prefill_batch(self, batch):
        """
        For all batches, no matter it is a new batch or a mini batch, we need to do prefill first.
        """
@@ -198,19 +212,20 @@ def _prefill_batch(self, batch):
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
-
+        async for item in self._handle_finish_req(batch, has_new_finished_req):
+            yield item
        # delete finished reqs

-    def _decode_batch(self, batch: Batch):
+    async def _decode_batch(self, batch: Batch):
        """
        Decoding process
        """
        ans = self.engine._decode_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
+        async for item in self._handle_finish_req(batch, has_new_finished_req):
+            yield item

    def _filter_batch(self, batch: Batch):
        batch_id = batch.batch_id
@@ -240,15 +255,15 @@ def _remove_batch(self, batch):
        batch.free_self()
        del batch

-    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
+    async def _handle_finish_req(self, batch: Batch, has_new_finished_req):
        if has_new_finished_req:
-            finished_reqs = batch.filter_finished()
+            finished_reqs = batch.filter_finished()
            if batch.is_clear():
                self._remove_batch(batch)
            else:
                self._filter_batch(batch)
-            yield from self._output_process(finished_reqs)
-
+            async for item in self._output_process(finished_reqs):
+                yield item

    def _filter_runing_batch(self):
        if self.running_batch is not None and self.running_batch.is_clear():
@@ -267,18 +282,24 @@ async def _output_process(self, finished_reqs: List[Req]):
        """
        for req in finished_reqs:
            output = self.tokenizer.decode(req.output_ids)
-            yield output, req.request_id, req.output_metadata_list
+            yield req.prompts + output

    def clean_up(self):
        # this logic should be implemented in the future.
        pass

-    async def generate(self,request_id,prompt_id,sampling_params):
+    async def generate(self, request_id, prompt_id, sampling_params):
        """
        Generate the output of a request.
        """
-        self.add_input(request_id,prompt_id,sampling_params)
-
+
+        await self.add_input(request_id, sampling_params, prompt_id)
+
+
+async def process_data(dbm):
+    async for data in dbm.loop_for_fwd():
+        print(data)
+

def start_dynamic_batching(args, tp_engine, waiting_req_list):
    try:
@@ -287,21 +308,13 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
            max_total_token_num=args.max_total_token_num,
            batch_max_tokens=args.batch_max_tokens,
            eos_id=args.eos_id,
+            model=args.model,
            log_stats=not args.disable_log_stats,
            log_stats_interval=args.log_stats_interval,
            waiting_req_list=waiting_req_list,
        )

    except Exception:
-        batch_manager.clean_up()
-        raise
-
-    batch_manager._set_tokenizer(tokenizer_name=tp_engine.model.__class__.__name__)
-    prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world"))
-
-    asyncio.run(prod_task)
-
-    for item in batch_manager.loop_for_fwd():
-        print(item)
+        raise RuntimeError("Failed to start dynamic batching")

    return batch_manager
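
A minimal end-to-end driver sketch for the new async interface, assuming a configured TPInferEngine, an argparse-style args object carrying the fields read in start_dynamic_batching, and a SamplingParams instance from the dynamic_batching package; the helper name demo and its arguments are illustrative and not part of this patch.

import asyncio


async def demo(args, tp_engine, sampling_params):
    # Hypothetical driver: `args`, `tp_engine` and `sampling_params` are assumed to be
    # constructed elsewhere; only the calls into the manager come from the code above.
    manager = start_dynamic_batching(args, tp_engine, waiting_req_list=[])

    # Queue a single prompt; generate() tokenizes it and appends a request to the req queue.
    await manager.generate(request_id="0", prompt_id="hello world", sampling_params=sampling_params)

    # Drain decoded outputs as the batching loop produces them (same role as process_data above).
    async for text in manager.loop_for_fwd():
        print(text)


# asyncio.run(demo(args, tp_engine, sampling_params))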