From fe14d75197e67f2a97093091623083d9034b4bd4 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Date: Thu, 9 Nov 2023 16:59:51 -0800
Subject: [PATCH] Recompute when a deadlock is detected (#278)

Co-authored-by: Michael Wyatt
---
 mii/batching/ragged_batching.py | 86 +++++++++++++++++++--------------
 1 file changed, 51 insertions(+), 35 deletions(-)

diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py
index 942f3277..a7693811 100644
--- a/mii/batching/ragged_batching.py
+++ b/mii/batching/ragged_batching.py
@@ -134,7 +134,7 @@ class RaggedRequest:
     tid: int
     uid: int
     input_tokens: torch.Tensor
-    prompt_length: int
+    prompt_tokens: torch.Tensor
     seq_length: int
     max_length: int
     max_new_tokens: int
@@ -148,6 +148,10 @@ class RaggedRequest:
     _generated_tokens: List[torch.Tensor] = field(default_factory=list)
     _finish_reason: GenerationFinishReason = GenerationFinishReason.NONE
 
+    @property
+    def prompt_length(self) -> int:
+        return len(self.prompt_tokens)
+
     @property
     def next_token(self) -> Union[None, torch.Tensor]:
         return self._next_token
@@ -205,6 +209,9 @@ def accumulate_generated_token(self) -> None:
         if not self.is_done:
             self._generated_tokens.append(self.next_token)
 
+    def clear_generated_token(self) -> None:
+        self._generated_tokens.clear()
+
     def set_next_as_input(self) -> None:
         if self.next_token is not None:
             self.input_tokens = self.next_token.unsqueeze(0)
@@ -554,11 +561,47 @@ def schedule_requests(self) -> None:
             self._do_schedule_requests(prompt_reqs)
 
         if len(self.buffer) > 0 and len(self.scheduled_requests) == 0:
-            raise RuntimeError("Deadlock detected: No requests were scheduled.")
+            print(
+                "Deadlock detected. Resetting KV cache and recomputing requests. Consider limiting the number of concurrent requests or decreasing the max lengths of prompts/generations."
+            )
+            self.scheduled_requests = RaggedRequestBatch([])
+            self.reset_request_status()
+        else:
+            scheduled_requests_ids = set(id(r) for r in self.scheduled_requests)
+            self.buffer = deque(
+                [r for r in self.buffer if id(r) not in scheduled_requests_ids])
+
+    def _queue_flush_request(self, uid: int) -> None:
+        self.request_queue.put_nowait(
+            RaggedRequest(
+                tid=None,
+                uid=uid,
+                input_tokens=None,
+                prompt_tokens=None,
+                seq_length=None,
+                max_length=None,
+                max_new_tokens=None,
+                last_in_prompt=None,
+                post_processing=None,
+                stream=None,
+            ))
+
+    def reset_request_status(self):
+        for r in self.buffer:
+            if r.seq_length > 0:
+                self._queue_flush_request(r.uid)
+
+        new_buffer = deque()
+        for r in self.buffer:
+            new_req = copy.copy(r)
+            new_req.prompt_tokens = new_req.input_tokens = torch.concat(
+                [r.prompt_tokens] + [t.unsqueeze(0) for t in r.generated_tokens])
+            new_req.seq_length = 0
+            new_req.max_new_tokens = r.max_new_tokens - len(r.generated_tokens)
+            new_req.clear_generated_token()
+            new_buffer.append(new_req)
 
-        scheduled_requests_ids = set(id(r) for r in self.scheduled_requests)
-        self.buffer = deque(
-            [r for r in self.buffer if id(r) not in scheduled_requests_ids])
+        self.buffer = new_buffer
 
     def make_request(self,
                      tid: int,
@@ -610,7 +653,7 @@ def make_request(self,
             tid=tid,
             uid=uid,
             input_tokens=input_tokens,
-            prompt_length=prompt_length,
+            prompt_tokens=input_tokens,
             seq_length=0,
             max_length=max_length,
             max_new_tokens=max_new_tokens,
@@ -663,7 +706,7 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch:
         while not self.result_queues[self.tid].empty():
             uid, response = self._get_response()
             outputs.append(response)
-            self._flush_uid(uid)
+            self._queue_flush_request(uid)
             uids_complete_order.append(uids_running.index(uid))
             uids_running.remove(uid)
         # Ensure final flush requests broadcast and
@@ -710,21 +753,6 @@ def _bcast_responses(self, responses: ResponseBatch) -> ResponseBatch:
         responses = ResponseBatch([Response.from_msg(msg) for msg in data_dicts])
         return responses
 
-    def _flush_uid(self, uid: int) -> None:
-        self.request_queue.put_nowait(
-            RaggedRequest(
-                tid=None,
-                uid=uid,
-                input_tokens=None,
-                prompt_length=None,
-                seq_length=None,
-                max_length=None,
-                max_new_tokens=None,
-                last_in_prompt=None,
-                post_processing=None,
-                stream=None,
-            ))
-
 
 class MIIAsyncPipeline(RaggedBatchBase):
     def __init__(self, *args, **kwargs):
@@ -823,17 +851,5 @@ def is_shutdown(self) -> bool:
     def flush_uid(self, uid: int) -> None:
         with self.lock:
             if self.is_rank_0:
-                self.request_queue.put_nowait(
-                    RaggedRequest(
-                        tid=None,
-                        uid=uid,
-                        input_tokens=None,
-                        prompt_length=None,
-                        seq_length=None,
-                        max_length=None,
-                        max_new_tokens=None,
-                        last_in_prompt=None,
-                        post_processing=None,
-                        stream=None,
-                    ))
+                self._queue_flush_request(uid)
             self.uids.remove(uid)
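
Note on the recovery path: the heart of this change is reset_request_status(). When the scheduler finds a non-empty buffer but cannot place a single request (typically because the KV cache is exhausted), it queues a flush sentinel for every request that already holds cached tokens (seq_length > 0) so their KV blocks are freed, then requeues each request so that its new prompt is the original prompt plus everything generated so far, its cached sequence length is zeroed, and its generation budget shrinks by the tokens already produced. The following is a minimal, self-contained sketch of that token-level transformation only; ToyRequest and its field subset are simplified, hypothetical stand-ins for RaggedRequest, not MII's actual types, and the flush-sentinel queueing is deliberately omitted.

# Sketch of the recompute-on-deadlock transformation (illustrative only).
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, List

import torch


@dataclass
class ToyRequest:
    # Hypothetical stand-in for RaggedRequest, modeling only the
    # fields that the recovery path touches.
    uid: int
    prompt_tokens: torch.Tensor   # 1-D tensor of prompt token ids
    seq_length: int               # tokens already held in the KV cache
    max_new_tokens: int           # remaining generation budget
    generated_tokens: List[torch.Tensor] = field(default_factory=list)


def reset_request_status(buffer: Deque[ToyRequest]) -> Deque[ToyRequest]:
    """Rebuild each in-flight request so it can be recomputed from scratch.

    Mirrors the patch: prompt := prompt + tokens generated so far,
    seq_length := 0 (the KV entries were flushed), and max_new_tokens
    is reduced by the number of tokens already produced.
    """
    new_buffer: Deque[ToyRequest] = deque()
    for r in buffer:
        # Each generated token is a 0-dim tensor; unsqueeze to 1-D
        # so it can be concatenated onto the prompt.
        new_prompt = torch.concat(
            [r.prompt_tokens] + [t.unsqueeze(0) for t in r.generated_tokens])
        new_buffer.append(
            ToyRequest(uid=r.uid,
                       prompt_tokens=new_prompt,
                       seq_length=0,
                       max_new_tokens=r.max_new_tokens - len(r.generated_tokens)))
    return new_buffer


if __name__ == "__main__":
    req = ToyRequest(uid=0,
                     prompt_tokens=torch.tensor([101, 102, 103]),
                     seq_length=5,
                     max_new_tokens=10,
                     generated_tokens=[torch.tensor(7), torch.tensor(8)])
    retried = reset_request_status(deque([req]))[0]
    assert retried.prompt_tokens.tolist() == [101, 102, 103, 7, 8]
    assert retried.seq_length == 0 and retried.max_new_tokens == 8

The key design choice this illustrates: rather than failing the batch with a RuntimeError, the scheduler trades recomputation cost for forward progress, since a retried request regenerates its KV entries from the extended prompt on the next prefill pass and loses no tokens already returned to the client.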