Commit fbf3c09

CjhHa1 and isky-cd authored
[inference]Re push async dynamic batching (#4901)

* adapt to ray server
* finish async
* finish test
* del test

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
1 parent: fced140

File tree

4 files changed: +107, -109 lines


colossalai/inference/dynamic_batching/io_struct.py (+10, -5)
@@ -4,7 +4,7 @@
 
 
 class Req:
-    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
+    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""):
         self.request_id = request_id
         self.prompt_ids = prompt_ids
         self.input_len = len(prompt_ids)
@@ -14,6 +14,7 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
         self.output_metadata_list = []
         self.has_generate_finished = False
         self.aborted = False
+        self.prompts = prompts
 
     def to_rpc_obj(self):
         return {
@@ -36,7 +37,11 @@ def stop_sequences_matched(self):
         if self.sample_params.stop_sequences is not None:
             for stop_token_ids in self.sample_params.stop_sequences:
                 stop_len = len(stop_token_ids)
-                if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
+                if (
+                    stop_len > 0
+                    and len(self.output_ids) >= stop_len
+                    and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
+                ):
                     return True
         return False
 
@@ -102,7 +107,7 @@ def mark_finished_req(self, eos_id):
                 has_new_finish = True
         return has_new_finish
 
-    def filter_finished(self)->List[Req]:
+    def filter_finished(self) -> List[Req]:
         """
         Filter finished requests from the batch, the finished ones will be removed from 'reqs'.
         """
@@ -111,9 +116,9 @@ def filter_finished(self)->List[Req]:
         finished_req = []
         for req in self.reqs:
             if not req.has_generate_finished:
-                unfinished_req.append(req)
+                unfinished_req.append(req)
             else:
-                finished_req.append(req)
+                finished_req.append(req)
         self.reqs = unfinished_req
         self.id_to_reqs = {req.request_id: req for req in self.reqs}
         return finished_req

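The reformatted condition in stop_sequences_matched above is simply a suffix match of the generated token ids against each configured stop sequence. A minimal standalone sketch of the same check follows; the helper name ends_with_stop_sequence is illustrative only and is not part of the codebase.

from typing import List


def ends_with_stop_sequence(output_ids: List[int], stop_sequences: List[List[int]]) -> bool:
    # True when output_ids ends with any stop sequence, equivalent to the
    # all(output_ids[-(stop_len - i)] == stop_token_ids[i] ...) condition above.
    for stop_token_ids in stop_sequences:
        stop_len = len(stop_token_ids)
        if stop_len > 0 and len(output_ids) >= stop_len and output_ids[-stop_len:] == stop_token_ids:
            return True
    return False


assert ends_with_stop_sequence([5, 6, 7, 2], [[7, 2]])
assert not ends_with_stop_sequence([5, 6, 7, 2], [[2, 7]])
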
colossalai/inference/manager.py (+76, -63)
@@ -1,6 +1,7 @@
-import time
-from typing import List
 import asyncio
+from typing import List
+
+from transformers import AutoTokenizer
 
 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
@@ -9,16 +10,17 @@
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
 
-from transformers import AutoTokenizer
 _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
+
 class DynamicBatchManager:
     def __init__(
         self,
         tp_engine: TPInferEngine,
         max_total_token_num,
         batch_max_tokens,
         eos_id,
+        model,
         log_stats=True,
         log_stats_interval=10,
         running_batch: Batch = None,
@@ -30,6 +32,7 @@ def __init__(
         batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
         running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
         eos_id : The end token of a seq
+        model: the model weight dir path, the app will load config, weights and tokenizer from this dir
         log_stats : whether to log stats
         log_stats_interval : log stats interval
         running_batch : running batch
@@ -45,32 +48,32 @@ def __init__(
         self.eos_id = eos_id
         self.has_wait_tokens = 0
         self.max_wait_tokens = 10
-
+        self.model = model
+
         self.stats_tool = Stats(log_stats, log_stats_interval)
         self.mem_usage_interval = log_stats_interval * 2
+        self._set_tokenizer(tokenizer_name=self.model)
 
-    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
+    async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
         """
         Add new request to req queue, during initialization all requests are held in waiting list.
         """
-        req = Req(request_id, prompt_ids, sampling_params)
+        req = Req(request_id, prompt_ids, sampling_params, prompts)
         self.req_queue.append(req)
         return
 
-    def add_input(self, request_id, sampling_params, input_ids):
+    async def add_input(self, request_id, sampling_params, prompts):
        """
        Encode and Add new input to req queue. support one sequence input for now.
        """
-        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_ids = self.tokenizer.encode(prompts)
        prompt_len = len(prompt_ids)
        if prompt_len > self.engine.max_input_len:
-            raise ValueError(
-                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
-            )
+            raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-        self.add_req(prompt_ids, sampling_params, request_id)
+        self.add_req(request_id, prompt_ids, sampling_params, prompts)
        return
-
+
    def abort(self, request_id):
        if self.running_batch is not None:
            for req in self.running_batch.reqs:
@@ -88,10 +91,15 @@ async def loop_for_fwd(self):
         The main loop for a dynamic batching process.
         """
         counter_count = 0
-        #self.running_batch is not None or self.req_queue.waiting_req_list
+        # self.running_batch is not None or self.req_queue.waiting_req_list
         while True:
-            async for item in self._step():
-                yield item
+            if self.running_batch is not None or self.req_queue.waiting_req_list:
+                async for result in self._step():
+                    yield result
+            else:
+                # need to wait for new requests
+                await asyncio.sleep(0.1)
+                continue
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
@@ -103,30 +111,33 @@ async def loop_for_fwd(self):
                     )
                 self.stats_tool.print_stats()
 
-        if self.running_batch is None:
-            time.sleep(0.1)  # 10ms
-
-    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,):
+    def _set_tokenizer(
+        self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True
+    ):
         if tokenizer is not None:
-            self.tokenizer = tokenizer
+            self.tokenizer = tokenizer
         else:
             if "llama" in tokenizer_name.lower() and use_fast == True:
                 print(
-                    "For some LLaMA-based models, initializing the fast tokenizer may "
-                    "take a long time. To eliminate the initialization time, consider "
-                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-                    "tokenizer. This is done automatically in Colossalai.")
-
-                tokenizer_name = _FAST_LLAMA_TOKENIZER
-
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-            except TypeError as e:
+                    "For some LLaMA-based models, initializing the fast tokenizer may "
+                    "take a long time. To eliminate the initialization time, consider "
+                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+                    "tokenizer. This is done automatically in Colossalai."
+                )
+
+                tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+                )
+            except TypeError:
                 use_fast = False
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+                )
 
-    def _step(self):
+    async def _step(self):
         """
         Logic for handling requests
         """
@@ -136,33 +147,36 @@ def _step(self):
         if new_batch is not None:
             self.stats_tool.count_prompt_tokens(new_batch)
             self.running_batch = new_batch
-            yield from self._prefill_batch(self.running_batch)
+            async for item in self._prefill_batch(self.running_batch):
+                yield item
             self._filter_runing_batch()
             self.has_wait_tokens = 0
             return
 
         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
         else:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                yield from self._prefill_batch(new_mini_batch)
+                async for item in self._prefill_batch(new_mini_batch):
+                    yield item
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
                     self.has_wait_tokens = 0
-
+
             else:
                 self.stats_tool.count_output_tokens(self.running_batch)
-                yield from self._decode_batch(self.running_batch)
+                async for item in self._decode_batch(self.running_batch):
+                    yield item
                 self._filter_runing_batch()
                 self.has_wait_tokens += 1
-
+
         return
 
     def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -187,7 +201,7 @@ def _init_batch(self, batch: Batch, dtype="fp16"):
         )
         self.engine.cache[batch_id] = batch_data
 
-    def _prefill_batch(self, batch):
+    async def _prefill_batch(self, batch):
         """
         For all batches, no matter it is a new batch or a mini batch, we need to do prefill first.
         """
@@ -198,19 +212,20 @@ def _prefill_batch(self, batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
-
+        async for item in self._handle_finish_req(batch, has_new_finished_req):
+            yield item
         # delete finished reqs
 
-    def _decode_batch(self, batch: Batch):
+    async def _decode_batch(self, batch: Batch):
         """
         Decoding process
         """
         ans = self.engine._decode_batch(batch.batch_id)
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
+        async for item in self._handle_finish_req(batch, has_new_finished_req):
+            yield item
 
     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -240,15 +255,15 @@ def _remove_batch(self, batch):
         batch.free_self()
         del batch
 
-    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
+    async def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            finished_reqs=batch.filter_finished()
+            finished_reqs = batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
-            yield from self._output_process(finished_reqs)
-
+            async for item in self._output_process(finished_reqs):
+                yield item
 
     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -267,18 +282,24 @@ async def _output_process(self, finished_reqs: List[Req]):
         """
         for req in finished_reqs:
             output = self.tokenizer.decode(req.output_ids)
-            yield output, req.request_id, req.output_metadata_list
+            yield req.prompts + output
 
     def clean_up(self):
         # this logic should be implemented in the future.
         pass
 
-    async def generate(self,request_id,prompt_id,sampling_params):
+    async def generate(self, request_id, prompt_id, sampling_params):
         """
         Generate the output of a request.
         """
-        self.add_input(request_id,prompt_id,sampling_params)
-
+
+        await self.add_input(request_id, prompt_id, sampling_params)
+
+
+async def process_data(dbm):
+    async for data in dbm.loop_for_fwd():
+        print(data)
+
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
     try:
@@ -287,21 +308,13 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
             max_total_token_num=args.max_total_token_num,
             batch_max_tokens=args.batch_max_tokens,
             eos_id=args.eos_id,
+            model=args.model,
             log_stats=not args.disable_log_stats,
             log_stats_interval=args.log_stats_interval,
             waiting_req_list=waiting_req_list,
        )
 
     except Exception:
-        batch_manager.clean_up()
-        raise
-
-    batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__)
-    prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world"))
-
-    asyncio.run(prod_task)
-
-    for item in batch_manager.loop_for_fwd():
-        print(item)
+        raise RuntimeError("Failed to start dynamic batching")
 
     return batch_manager

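Taken together, the async entry points added in manager.py would be driven roughly as in the sketch below. This is a minimal usage sketch under assumptions, not code from the commit: the SamplingParams import path is assumed, and args, tp_engine, and waiting_req_list stand for objects the caller already has (the same ones start_dynamic_batching expects). Keyword arguments are used for add_input to match the new signature shown above.

import asyncio

# Assumed import locations for this sketch.
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import process_data, start_dynamic_batching


async def serve_one_prompt(args, tp_engine, waiting_req_list):
    # Build the manager; __init__ now also loads the tokenizer from args.model.
    dbm = start_dynamic_batching(args, tp_engine, waiting_req_list)

    # Queue a request. add_input is now a coroutine: it tokenizes the prompt,
    # validates the length against max_input_len, and hands it to add_req as a
    # Req that also carries the raw prompt string.
    await dbm.add_input(request_id=0, sampling_params=SamplingParams(), prompts="hello world")

    # Drain outputs. loop_for_fwd is an async generator that yields
    # req.prompts + decoded text for each finished request; process_data
    # prints whatever it yields and runs until cancelled.
    await process_data(dbm)
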
colossalai/inference/test_async.py (-33): this file was deleted.
