From 5d60967edd74cd5a7bfce917c94e8737d76fa333 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 8 Nov 2023 14:29:51 -0800 Subject: [PATCH 1/4] add message about deadlock and reserve some space in KV cache to mitigate the risk --- mii/batching/ragged_batching.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index be53ee71..8ae78c43 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -480,6 +480,14 @@ def _do_schedule_requests(self, requests: List[RaggedRequest]) -> None: break max_blocks = free_blocks - self.scheduled_req_blocks + + # Check capacity to mitigate the deadlock risk + # We don't schedule requests when we find that a prompt is too long to fit to the KV cache + if len(r.input_tokens) > 1: + req_tokens, _ = self.inference_engine.query(r.uid, len(r.input_tokens), max_blocks) + if req_tokens < len(r.input_tokens): + break + req_tokens = min(len(r.input_tokens), max_batch_size) req_tokens, req_blocks = self.inference_engine.query(r.uid, req_tokens, max_blocks) @@ -528,6 +536,9 @@ def schedule_requests(self) -> None: self._do_schedule_requests(next_token_gen_reqs) self._do_schedule_requests(prompt_reqs) + if len(self.buffer) > 0 and len(self.scheduled_requests) == 0: + raise RuntimeError("Deadlock detected: No requests were scheduled.") + scheduled_requests_ids = set(id(r) for r in self.scheduled_requests) self.buffer = deque( [r for r in self.buffer if id(r) not in scheduled_requests_ids]) From b71341405883682a0ac0accb29c9e93316c73f40 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 8 Nov 2023 16:29:04 -0800 Subject: [PATCH 2/4] recompute when deadlock is detected --- mii/batching/ragged_batching.py | 42 ++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 8ae78c43..420170ad 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -125,7 +125,7 @@ def from_msg(msg: Dict[str, int]) -> Self: class RaggedRequest: uid: int input_tokens: torch.Tensor - prompt_length: int + prompt_tokens: torch.Tensor seq_length: int max_length: int max_new_tokens: int @@ -141,6 +141,10 @@ class RaggedRequest: _generated_tokens: List[torch.Tensor] = field(default_factory=list) _finish_reason: GenerationFinishReason = GenerationFinishReason.NONE + @property + def prompt_length(self) -> int: + return len(self.prompt_tokens) + @property def next_token(self) -> Union[None, torch.Tensor]: return self._next_token @@ -198,6 +202,9 @@ def accumulate_generated_token(self) -> None: if not self.is_done: self._generated_tokens.append(self.next_token) + def clear_generated_token(self) -> None: + self._generated_tokens.clear() + def set_next_as_input(self) -> None: if self.next_token is not None: self.input_tokens = self.next_token.unsqueeze(0) @@ -537,11 +544,30 @@ def schedule_requests(self) -> None: self._do_schedule_requests(prompt_reqs) if len(self.buffer) > 0 and len(self.scheduled_requests) == 0: - raise RuntimeError("Deadlock detected: No requests were scheduled.") + print( + "Deadlock detected. Resetting KV cache and recomputing requests. Consider limiting number of concurrent requests or decreasing max lengths of prompts/generations." 
+ ) + self.scheduled_requests = RaggedRequestBatch([]) + self.reset_request_status() + else: + scheduled_requests_ids = set(id(r) for r in self.scheduled_requests) + self.buffer = deque( + [r for r in self.buffer if id(r) not in scheduled_requests_ids]) + + def reset_request_status(self): + self.flush([r.uid for r in self.buffer if r.seq_length > 0]) + + new_buffer = deque() + for r in self.buffer: + new_req = copy.copy(r) + new_req.prompt_tokens = new_req.input_tokens = torch.concat( + [r.prompt_tokens] + [t.unsqueeze(0) for t in r.generated_tokens]) + new_req.seq_length = 0 + new_req.max_new_tokens = r.max_new_tokens - len(r.generated_tokens) + new_req.clear_generated_token() + new_buffer.append(new_req) - scheduled_requests_ids = set(id(r) for r in self.scheduled_requests) - self.buffer = deque( - [r for r in self.buffer if id(r) not in scheduled_requests_ids]) + self.buffer = new_buffer def make_request(self, uid: int, @@ -584,7 +610,7 @@ def make_request(self, RaggedRequest( uid=uid, input_tokens=input_tokens, - prompt_length=len(input_tokens), + prompt_tokens=input_tokens, seq_length=0, max_length=max_length, max_new_tokens=max_new_tokens, @@ -637,7 +663,7 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch: RaggedRequest( uid=uid, input_tokens=None, - prompt_length=None, + prompt_tokens=None, seq_length=None, max_length=None, max_new_tokens=None, @@ -796,7 +822,7 @@ def destroy_session(self, RaggedRequest( uid=uid, input_tokens=None, - prompt_length=None, + prompt_tokens=None, seq_length=None, max_length=None, max_new_tokens=None, From f8f2c5431286d4d911f4a47e3838b961426a4d20 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Thu, 9 Nov 2023 15:37:16 -0800 Subject: [PATCH 3/4] make sure flush request happens on all ranks --- mii/batching/ragged_batching.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 89ce0729..e6fad5f0 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -568,8 +568,24 @@ def schedule_requests(self) -> None: self.buffer = deque( [r for r in self.buffer if id(r) not in scheduled_requests_ids]) + def _queue_flush_request(self, uid: int) -> None: + self.request_queue.put_nowait( + RaggedRequest( + uid=uid, + input_tokens=None, + prompt_tokens=None, + seq_length=None, + max_length=None, + max_new_tokens=None, + last_in_prompt=None, + post_processing=None, + stream=None, + )) + def reset_request_status(self): - self.flush([r.uid for r in self.buffer if r.seq_length > 0]) + for r in self.buffer: + if r.seq_length > 0: + self._queue_flush_request(r.uid) new_buffer = deque() for r in self.buffer: From 78b35cc6ec1bb25d16455e9ba75f4de7c78ae608 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Thu, 9 Nov 2023 16:54:25 -0800 Subject: [PATCH 4/4] fix merge errors --- mii/batching/ragged_batching.py | 60 +++++++++------------------------ 1 file changed, 16 insertions(+), 44 deletions(-) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index d1783b50..a7693811 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -574,6 +574,7 @@ def schedule_requests(self) -> None: def _queue_flush_request(self, uid: int) -> None: self.request_queue.put_nowait( RaggedRequest( + tid=None, uid=uid, input_tokens=None, prompt_tokens=None, @@ -648,21 +649,19 @@ def make_request(self, assert kwargs == {}, f"Unknown keyword arguments {kwargs}" - return [ - RaggedRequest( 
- tid=tid, - uid=uid, - input_tokens=input_tokens, - prompt_tokens=input_tokens, - seq_length=0, - max_length=max_length, - max_new_tokens=max_new_tokens, - last_in_prompt=True, - post_processing=post_processing, - stream=stream, - ignore_eos=ignore_eos, - ) - ] + return RaggedRequest( + tid=tid, + uid=uid, + input_tokens=input_tokens, + prompt_tokens=input_tokens, + seq_length=0, + max_length=max_length, + max_new_tokens=max_new_tokens, + last_in_prompt=True, + post_processing=post_processing, + stream=stream, + ignore_eos=ignore_eos, + ) def make_response(self, generated_text: str, @@ -707,7 +706,7 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch: while not self.result_queues[self.tid].empty(): uid, response = self._get_response() outputs.append(response) - self._flush_uid(uid) + self._queue_flush_request(uid) uids_complete_order.append(uids_running.index(uid)) uids_running.remove(uid) # Ensure final flush requests broadcast and @@ -754,21 +753,6 @@ def _bcast_responses(self, responses: ResponseBatch) -> ResponseBatch: responses = ResponseBatch([Response.from_msg(msg) for msg in data_dicts]) return responses - def _flush_uid(self, uid: int) -> None: - self.request_queue.put_nowait( - RaggedRequest( - tid=None, - uid=uid, - input_tokens=None, - prompt_length=None, - seq_length=None, - max_length=None, - max_new_tokens=None, - last_in_prompt=None, - post_processing=None, - stream=None, - )) - class MIIAsyncPipeline(RaggedBatchBase): def __init__(self, *args, **kwargs): @@ -867,17 +851,5 @@ def is_shutdown(self) -> bool: def flush_uid(self, uid: int) -> None: with self.lock: if self.is_rank_0: - self.request_queue.put_nowait( - RaggedRequest( - tid=None, - uid=uid, - input_tokens=None, - prompt_tokens=None, - seq_length=None, - max_length=None, - max_new_tokens=None, - last_in_prompt=None, - post_processing=None, - stream=None, - )) + self._queue_flush_request(uid) self.uids.remove(uid)
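Taken together, patches 1 and 2 change the scheduler so that it first refuses to schedule work that cannot fit in the free KV-cache blocks and then, if requests are still waiting but nothing at all could be scheduled, resets the cache and recomputes the stalled requests instead of raising. What follows is a minimal, self-contained sketch of that policy, not the MII implementation; every name in it (ToyRequest, ToyScheduler, BLOCK_SIZE) is hypothetical and the block accounting is deliberately simplified.

# Hedged sketch of the capacity check (patch 1) and deadlock recovery (patch 2).
# Not the MII API; ToyRequest, ToyScheduler and BLOCK_SIZE are illustrative only.
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, List

BLOCK_SIZE = 16  # assumed number of tokens per KV-cache block


@dataclass
class ToyRequest:
    uid: int
    prompt_tokens: List[int]
    generated_tokens: List[int] = field(default_factory=list)
    seq_length: int = 0  # tokens already resident in the KV cache

    @property
    def input_tokens(self) -> List[int]:
        # Whole prompt on the first pass, a single next token afterwards.
        return self.prompt_tokens if self.seq_length == 0 else self.generated_tokens[-1:]


class ToyScheduler:
    def __init__(self, total_blocks: int) -> None:
        self.total_blocks = total_blocks
        self.used_blocks = 0
        self.buffer: Deque[ToyRequest] = deque()

    def _blocks_needed(self, n_tokens: int) -> int:
        return -(-n_tokens // BLOCK_SIZE)  # ceiling division

    def schedule(self) -> List[ToyRequest]:
        scheduled: List[ToyRequest] = []
        for r in list(self.buffer):
            free_blocks = self.total_blocks - self.used_blocks
            # Patch 1: check capacity up front and stop scheduling once a request
            # (typically a long prompt) cannot fit in the remaining free blocks.
            if self._blocks_needed(len(r.input_tokens)) > free_blocks:
                break
            self.used_blocks += self._blocks_needed(len(r.input_tokens))
            scheduled.append(r)

        # Patch 2: requests are waiting but none could be scheduled -> deadlock.
        # Instead of raising, release the KV cache and set up a full recompute.
        if self.buffer and not scheduled:
            self.reset_request_status()
        return scheduled

    def reset_request_status(self) -> None:
        # Mirrors the spirit of reset_request_status() in the patch: flush cached
        # state and rebuild each request so that prompt + tokens generated so far
        # become the new prompt, to be recomputed from sequence position zero.
        self.used_blocks = 0
        self.buffer = deque(
            ToyRequest(uid=r.uid, prompt_tokens=r.prompt_tokens + r.generated_tokens)
            for r in self.buffer)

The real reset_request_status() additionally reduces max_new_tokens by the number of tokens already generated and only flushes requests whose seq_length is non-zero; the sketch omits both details.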
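Patches 3 and 4 make the flush path rank-safe: rather than releasing KV-cache state directly, the code enqueues a RaggedRequest whose only meaningful field is the uid, so the same flush is replayed on every rank that consumes the broadcast request queue. A rough sketch of that pattern follows; the names (FlushMarker, WorkItem, Worker) are assumptions for illustration, and a plain in-process queue stands in for the broadcast mechanism.

# Hedged sketch of the "flush via the request queue" pattern (patches 3 and 4).
# Illustrative names only; a local queue.Queue stands in for rank-0 broadcast.
import queue
from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class FlushMarker:
    uid: int  # request whose KV-cache blocks should be released on every rank


@dataclass
class WorkItem:
    uid: int
    tokens: List[int]


class Worker:
    def __init__(self) -> None:
        self.kv_cache: Dict[int, List[int]] = {}  # uid -> cached tokens (stand-in)

    def step(self, msg: Union[WorkItem, FlushMarker]) -> None:
        if isinstance(msg, FlushMarker):
            # Every rank sees the same marker, so every rank drops the same uid.
            self.kv_cache.pop(msg.uid, None)
        else:
            self.kv_cache.setdefault(msg.uid, []).extend(msg.tokens)


request_queue: "queue.Queue[Union[WorkItem, FlushMarker]]" = queue.Queue()
request_queue.put_nowait(WorkItem(uid=1, tokens=[101, 102]))
request_queue.put_nowait(FlushMarker(uid=1))  # analogue of _queue_flush_request(1)

worker = Worker()
while not request_queue.empty():
    worker.step(request_queue.get_nowait())
assert 1 not in worker.kv_cache

Patch 4 then removes the duplicated hand-built flush requests (_flush_uid and the inline construction in MIIAsyncPipeline.flush_uid) in favor of the shared _queue_flush_request helper.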