
Commit cf579ff

CjhHa1, isky-cd, and tiandiao123 authored
[Inference] Dynamic Batching Inference, online and offline (#4953)
* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)
* finish batch manager
* 1
* first
* fix
* fix dynamic batching
* llama infer
* finish test
* support different lengths generating
* del prints
* del prints
* fix
* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>

* [inference] Async dynamic batching (#4894)
* finish input and output logic
* add generate
* test forward
* 1
* [inference]Re push async dynamic batching (#4901)
* adapt to ray server
* finish async
* finish test
* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)

This reverts commit fbf3c09.

* Revert "[inference] Async dynamic batching (#4894)"

This reverts commit fced140.

* Revert "[inference] Async dynamic batching (#4894)" (#4909)

This reverts commit fced140.

* Add Ray Distributed Environment Init Scripts
* support DynamicBatchManager base function
* revert _set_tokenizer version
* add driver async generate
* add async test
* fix bugs in test_ray_dist.py
* add get_tokenizer.py
* fix code style
* fix bugs about No module named 'pydantic' in ci test
* fix bugs in ci test
* fix bugs in ci test
* fix bugs in ci test
* [infer]Add Ray Distributed Environment Init Scripts (#4911)
* Revert "[inference] Async dynamic batching (#4894)"

This reverts commit fced140.

* Add Ray Distributed Environment Init Scripts
* support DynamicBatchManager base function
* revert _set_tokenizer version
* add driver async generate
* add async test
* fix bugs in test_ray_dist.py
* add get_tokenizer.py
* fix code style
* fix bugs about No module named 'pydantic' in ci test
* fix bugs in ci test
* fix bugs in ci test
* fix bugs in ci test
* support dynamic batch for bloom model and is_running function
* [Inference]Test for new Async engine (#4935)
* infer engine
* infer engine
* test engine
* test engine
* new manager
* change step
* add
* test
* fix
* fix
* finish test
* finish test
* finish test
* finish test
* add license

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* add assertion for config (#4947)
* [Inference] Finish dynamic batching offline test (#4948)
* test
* fix test
* fix quant
* add default
* fix
* fix some bugs
* fix some bugs
* fix
* fix bug
* fix bugs
* reset param

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
1 parent 4e4a10c commit cf579ff

30 files changed: +2005 -92 lines changed

colossalai/inference/async_engine.py

+133
@@ -0,0 +1,133 @@
import asyncio

from colossalai.inference.dynamic_batching.ray_dist_init import Driver

from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams


class RequestTracker:
    """
    A class that tracks all requests; the async abstraction over the engine.
    """

    def __init__(self) -> None:
        self._requests: asyncio.Queue[str] = asyncio.Queue()
        self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
        self.new_requests_event = None

    def __contains__(self, item):
        return item in self._requests

    def init_event(self):
        self.new_requests_event = asyncio.Event()

    def add_request(self, request_id: str):
        """Add a request to be sent to the engine on the next background
        loop iteration."""
        self._requests.put_nowait(request_id)
        self.new_requests_event.set()  # NOTE: we may find a better way to clear this event

    def add_stop(self):
        """
        Add a StopIteration flag to stop the async generator.
        """
        self._finished_requests.put_nowait(StopIteration)
        self.new_requests_event.clear()

    def process_request_output(self, request_output: RequestOutput) -> None:
        """Process a request output from the engine."""
        self._finished_requests.put_nowait(request_output)

    async def wait_for_new_requests(self):
        await self.new_requests_event.wait()

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        result = await self._finished_requests.get()
        if result is StopIteration:
            raise StopAsyncIteration
        return result


class Async_Engine:

    """
    Use an engine to launch the RAY Driver --> RAY Worker --> Async_Manager chain.
    Background loop: runs inference on the requests in the waiting list (listens).
    Request tracker: manages incoming requests and stores finished ones.
    Generate: the exposed function for adding new input and returning finished results.
    """

    def __init__(
        self,
        router_config,
        engine_config,
        start_engine_loop: bool = True,
    ) -> None:
        self.driver = Driver(router_config=router_config, engine_config=engine_config)
        self.background_loop = None
        self.start_engine_loop = start_engine_loop
        self._request_tracker = RequestTracker()

    def _step(self):
        """
        Logic for handling requests
        """
        request_outputs = self.driver.step()
        if request_outputs is not None:
            for request_output in request_outputs:
                self._request_tracker.process_request_output(request_output)
            self._request_tracker.add_stop()

    def abort_request(self, request_id: str):
        self.driver.abort(request_id)

    def _has_requests_in_progress(self):
        return self.driver.is_running()

    async def run_loop_fwd(self):
        has_requests_in_progress = self._has_requests_in_progress()
        while True:
            if not has_requests_in_progress:
                await self._request_tracker.wait_for_new_requests()
            self._step()
            await asyncio.sleep(0)

    @property
    def is_running(self):
        return self.background_loop is not None and not self.background_loop.done()

    def start_background_loop(self):
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        self._request_tracker.init_event()

        self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
        self.background_loop = asyncio.shield(self.background_loop_unshielded)

    async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        self.driver.add_input(request_id, prompt, sampling_params)
        self._request_tracker.add_request(request_id)

    async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        """
        The only exposed function: adds a new request and returns an async generator that yields the results as they finish.
        """
        try:
            if not self.is_running:
                self.start_background_loop()

            await self.add_request(request_id, prompt, sampling_params)

            async for request_output in self._request_tracker:
                yield request_output

        except (Exception, asyncio.CancelledError) as e:
            # If there is an exception or the coroutine is cancelled, abort the request.
            self.abort_request(request_id)
            raise e
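
As a usage illustration (not part of this commit): a minimal sketch of driving the async generator exposed by Async_Engine.generate. The ROUTER_CONFIG/ENGINE_CONFIG dictionaries and their keys are placeholders standing in for whatever Driver actually expects, and a running Ray environment plus model weights would be required.

import asyncio

from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

# Placeholder configs: the real keys are whatever Driver expects for the Ray
# router and the TP inference engine (model path, tp size, max lengths, ...).
ROUTER_CONFIG = {"model": "/path/to/llama-7b"}
ENGINE_CONFIG = {"tp_size": 2, "max_batch_size": 8, "max_input_len": 1024, "max_output_len": 256}


async def main():
    engine = Async_Engine(router_config=ROUTER_CONFIG, engine_config=ENGINE_CONFIG)
    sampling_params = SamplingParams()  # default sampling settings

    # generate() lazily starts the background loop, registers the request with
    # the RequestTracker, and yields RequestOutput objects as they finish.
    async for request_output in engine.generate("req-0", "Introduce some landmarks in Beijing", sampling_params):
        print(request_output)


if __name__ == "__main__":
    asyncio.run(main())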

colossalai/inference/async_manager.py

+151
@@ -0,0 +1,151 @@
from typing import List

from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine


class Async_DynamicBatchManager(DynamicBatchManager):
    def __init__(
        self,
        tp_engine: TPInferEngine,
        max_total_token_num: int,
        batch_max_tokens: int,
        model: str,
        tokenizer=None,
        eos_id=None,
        log_stats=True,
        log_stats_interval=10,
        running_batch: Batch = None,
        waiting_req_list: List = [],
    ):
        """
        Args:
            tp_engine : the TP engine that the dynamic batch manager holds, defined before the dynamic batch manager
            max_total_token_num : max_total_token_num for the memory manager, defaults to max batch size * (max input len + max output len)
            batch_max_tokens : max tokens of one batch, defaults to (max input len + max output len) * num_requests
            running_max_req_size : max request size of the running batch, equal to MAX_BATCH_SIZE of the TP engine
            eos_id : the end token of a sequence
            model : the model weight dir path; the app will load config, weights and tokenizer from this dir
            log_stats : whether to log stats
            log_stats_interval : log stats interval
            running_batch : running batch
            waiting_req_list : list of waiting requests, initialized before the dynamic batch manager
        """
        super().__init__(
            tp_engine,
            max_total_token_num,
            batch_max_tokens,
            model,
            tokenizer,
            eos_id,
            log_stats,
            log_stats_interval,
            running_batch,
            waiting_req_list,
        )

    def _step(self):
        """
        Logic for handling requests
        """
        has_new_finished = False
        if self.running_batch is None:
            new_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_batch is not None:
                self.stats_tool.count_prompt_tokens(new_batch)
                self.running_batch = new_batch
                has_new_finished, outputs = self._prefill_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens = 0

        else:
            if self.has_wait_tokens < self.max_wait_tokens:
                self.stats_tool.count_output_tokens(self.running_batch)
                has_new_finished, outputs = self._decode_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens += 1

            else:
                new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
                if new_mini_batch is not None:
                    self.stats_tool.count_prompt_tokens(new_mini_batch)
                    has_new_finished, outputs = self._prefill_batch(new_mini_batch)
                    if not new_mini_batch.is_clear():
                        self._merge_batch(self.running_batch, new_mini_batch)
                        self.running_batch.merge(new_mini_batch)
                    self.has_wait_tokens = 0

                else:
                    self.stats_tool.count_output_tokens(self.running_batch)
                    has_new_finished, outputs = self._decode_batch(self.running_batch)
                    self._filter_runing_batch()
                    self.has_wait_tokens += 1

        if has_new_finished:
            return outputs
        return None

    def _prefill_batch(self, batch):
        """
        Every batch, whether a new batch or a mini batch, needs to be prefilled first.
        """
        self._init_batch(batch)

        # TODO: figure out if cache and batch id is needed
        ans = self.engine._prefill_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        outputs = self._handle_finish_req(batch, has_new_finished_req)  # delete finished reqs
        return has_new_finished_req, outputs

    def _decode_batch(self, batch: Batch):
        """
        Decoding process
        """
        ans = self.engine._decode_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        outputs = self._handle_finish_req(batch, has_new_finished_req)
        return has_new_finished_req, outputs

    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
        if has_new_finished_req:
            finished_reqs = batch.filter_finished()
            if batch.is_clear():
                self._remove_batch(batch)
            else:
                self._filter_batch(batch)
            return self._output_process(finished_reqs)
        return None

    def _output_process(self, finished_reqs: List[Req]):
        """
        Process the output of a batch.
        """
        outputs = []
        for req in finished_reqs:
            output = self.tokenizer.decode(req.output_ids)
            outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
        return outputs


def start_dynamic_batching(args, tp_engine, waiting_req_list):
    try:
        batch_manager = Async_DynamicBatchManager(
            tp_engine=tp_engine,
            max_total_token_num=args.max_total_token_num,
            batch_max_tokens=args.batch_max_tokens,
            eos_id=args.eos_id,
            model=args.model,
            log_stats=not args.disable_log_stats,
            log_stats_interval=args.log_stats_interval,
            waiting_req_list=waiting_req_list,
        )

    except Exception:
        raise Exception

    return batch_manager
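
To make the capacity arguments documented above concrete, here is a small sketch (with made-up engine limits, not values from this commit) of how max_total_token_num and batch_max_tokens could be derived and passed to start_dynamic_batching through an args namespace.

from argparse import Namespace

# Assumed engine limits, for illustration only.
max_batch_size, max_input_len, max_output_len = 8, 1024, 256

args = Namespace(
    model="/path/to/llama-7b",  # weight dir: config, weights and tokenizer are loaded from here
    # default suggested by the docstring: max batch size * (max input len + max output len)
    max_total_token_num=max_batch_size * (max_input_len + max_output_len),  # 8 * 1280 = 10240
    # upper bound on the tokens handled in a single batch
    batch_max_tokens=max_batch_size * (max_input_len + max_output_len),
    eos_id=2,  # e.g. LLaMA's </s> token id
    disable_log_stats=False,
    log_stats_interval=10,
)

# With an initialized TPInferEngine `tp_engine` and a list of queued Req objects:
# batch_manager = start_dynamic_batching(args, tp_engine=tp_engine, waiting_req_list=[])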

colossalai/inference/dynamic_batching/__init__.py

Whitespace-only changes.
colossalai/inference/dynamic_batching/get_tokenizer.py

+40
@@ -0,0 +1,40 @@
"""
Motivated by vLLM (https://github.com/vllm-project/vllm), this module resolves the tokenizer loading issue.

license: MIT, see LICENSE for more details.
"""

from transformers import AutoTokenizer

_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer=None,
    tokenizer_name: str = "",
    trust_remote_code: bool = False,
    use_fast: bool = True,
):
    if tokenizer is not None:
        tokenizer = tokenizer
    else:
        if "llama" in tokenizer_name.lower() and use_fast == True:
            print(
                "For some LLaMA-based models, initializing the fast tokenizer may "
                "take a long time. To eliminate the initialization time, consider "
                f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
                "tokenizer. This is done automatically in Colossalai."
            )

            tokenizer_name = _FAST_LLAMA_TOKENIZER

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
            )
        except TypeError:
            use_fast = False
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
            )
    return tokenizer
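
A short usage sketch for the helper above, assuming the module path colossalai.inference.dynamic_batching.get_tokenizer; the model names are only examples and loading them requires access to the Hugging Face Hub.

from transformers import AutoTokenizer

from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer  # assumed module path

# Pass an existing tokenizer straight through...
tok = get_tokenizer(tokenizer=AutoTokenizer.from_pretrained("gpt2"))

# ...or load one by name; a LLaMA-style name with use_fast=True is swapped
# for the hf-internal-testing/llama-tokenizer to avoid slow initialization.
tok = get_tokenizer(tokenizer_name="huggyllama/llama-7b", use_fast=True)
print(type(tok).__name__)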
