Recompute when the deadlock is detected (#278)
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
tohtana and mrwyattii authored Nov 10, 2023
1 parent 1867610 commit fe14d75
Showing 1 changed file with 51 additions and 35 deletions.
mii/batching/ragged_batching.py: 86 changes (51 additions & 35 deletions)
@@ -134,7 +134,7 @@ class RaggedRequest:
tid: int
uid: int
input_tokens: torch.Tensor
prompt_length: int
prompt_tokens: torch.Tensor
seq_length: int
max_length: int
max_new_tokens: int
@@ -148,6 +148,10 @@ class RaggedRequest:
_generated_tokens: List[torch.Tensor] = field(default_factory=list)
_finish_reason: GenerationFinishReason = GenerationFinishReason.NONE

@property
def prompt_length(self) -> int:
return len(self.prompt_tokens)

@property
def next_token(self) -> Union[None, torch.Tensor]:
return self._next_token
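
The two hunks above replace the stored prompt_length: int field with the full prompt_tokens tensor and turn prompt_length into a derived property; keeping the prompt tokens themselves is what later lets a request be rebuilt and recomputed after a deadlock. A minimal sketch of the pattern, not part of the diff, using a stripped-down stand-in for RaggedRequest:

from dataclasses import dataclass

import torch


@dataclass
class PromptOnlyRequest:
    # Illustrative stand-in for RaggedRequest: store the prompt, not just its length.
    prompt_tokens: torch.Tensor

    @property
    def prompt_length(self) -> int:
        # Derived on demand, so it stays correct if prompt_tokens is rebuilt later.
        return len(self.prompt_tokens)


req = PromptOnlyRequest(prompt_tokens=torch.tensor([101, 2023, 2003, 102]))
assert req.prompt_length == 4
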
@@ -205,6 +209,9 @@ def accumulate_generated_token(self) -> None:
if not self.is_done:
self._generated_tokens.append(self.next_token)

def clear_generated_token(self) -> None:
self._generated_tokens.clear()

def set_next_as_input(self) -> None:
if self.next_token is not None:
self.input_tokens = self.next_token.unsqueeze(0)
@@ -554,11 +561,47 @@ def schedule_requests(self) -> None:
self._do_schedule_requests(prompt_reqs)

if len(self.buffer) > 0 and len(self.scheduled_requests) == 0:
raise RuntimeError("Deadlock detected: No requests were scheduled.")
print(
"Deadlock detected. Resetting KV cache and recomputing requests. Consider limiting number of concurrent requests or decreasing max lengths of prompts/generations."
)
self.scheduled_requests = RaggedRequestBatch([])
self.reset_request_status()
else:
scheduled_requests_ids = set(id(r) for r in self.scheduled_requests)
self.buffer = deque(
[r for r in self.buffer if id(r) not in scheduled_requests_ids])

def _queue_flush_request(self, uid: int) -> None:
self.request_queue.put_nowait(
RaggedRequest(
tid=None,
uid=uid,
input_tokens=None,
prompt_tokens=None,
seq_length=None,
max_length=None,
max_new_tokens=None,
last_in_prompt=None,
post_processing=None,
stream=None,
))

def reset_request_status(self):
for r in self.buffer:
if r.seq_length > 0:
self._queue_flush_request(r.uid)

new_buffer = deque()
for r in self.buffer:
new_req = copy.copy(r)
new_req.prompt_tokens = new_req.input_tokens = torch.concat(
[r.prompt_tokens] + [t.unsqueeze(0) for t in r.generated_tokens])
new_req.seq_length = 0
new_req.max_new_tokens = r.max_new_tokens - len(r.generated_tokens)
new_req.clear_generated_token()
new_buffer.append(new_req)

scheduled_requests_ids = set(id(r) for r in self.scheduled_requests)
self.buffer = deque(
[r for r in self.buffer if id(r) not in scheduled_requests_ids])
self.buffer = new_buffer

def make_request(self,
tid: int,
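
Instead of raising a RuntimeError, schedule_requests now treats "buffer is non-empty but nothing could be scheduled" as a recoverable deadlock. It clears the scheduled batch, queues a flush for every request that already holds KV-cache state (seq_length > 0), and rebuilds each buffered request: the original prompt plus all tokens generated so far become the new prompt, seq_length goes back to 0, and the remaining generation budget shrinks by the number of tokens already produced. A rough, self-contained sketch of the rebuild step, not part of the diff (SimpleRequest and its plain generated_tokens list are illustrative stand-ins for the real RaggedRequest):

import copy
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, List

import torch


@dataclass
class SimpleRequest:
    # Illustrative subset of RaggedRequest's fields.
    uid: int
    prompt_tokens: torch.Tensor
    input_tokens: torch.Tensor
    seq_length: int
    max_new_tokens: int
    generated_tokens: List[torch.Tensor] = field(default_factory=list)


def reset_for_recompute(buffer: Deque[SimpleRequest]) -> Deque[SimpleRequest]:
    # Rebuild every buffered request so it can be recomputed from an empty KV cache.
    new_buffer: Deque[SimpleRequest] = deque()
    for r in buffer:
        new_req = copy.copy(r)
        # The retry's prompt = original prompt + everything generated so far.
        new_req.prompt_tokens = new_req.input_tokens = torch.concat(
            [r.prompt_tokens] + [t.unsqueeze(0) for t in r.generated_tokens])
        new_req.seq_length = 0  # no KV-cache state survives the reset
        new_req.max_new_tokens = r.max_new_tokens - len(r.generated_tokens)
        new_req.generated_tokens = []
        new_buffer.append(new_req)
    return new_buffer


# A request that had already produced two tokens when the deadlock hit.
req = SimpleRequest(uid=0,
                    prompt_tokens=torch.tensor([1, 2, 3]),
                    input_tokens=torch.tensor([9]),
                    seq_length=5,
                    max_new_tokens=10,
                    generated_tokens=[torch.tensor(7), torch.tensor(9)])
rebuilt = reset_for_recompute(deque([req]))[0]
assert rebuilt.prompt_tokens.tolist() == [1, 2, 3, 7, 9]
assert rebuilt.seq_length == 0 and rebuilt.max_new_tokens == 8
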
@@ -610,7 +653,7 @@ def make_request(self,
tid=tid,
uid=uid,
input_tokens=input_tokens,
prompt_length=prompt_length,
prompt_tokens=input_tokens,
seq_length=0,
max_length=max_length,
max_new_tokens=max_new_tokens,
@@ -663,7 +706,7 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch:
while not self.result_queues[self.tid].empty():
uid, response = self._get_response()
outputs.append(response)
self._flush_uid(uid)
self._queue_flush_request(uid)
uids_complete_order.append(uids_running.index(uid))
uids_running.remove(uid)
# Ensure final flush requests broadcast and
@@ -710,21 +753,6 @@ def _bcast_responses(self, responses: ResponseBatch) -> ResponseBatch:
responses = ResponseBatch([Response.from_msg(msg) for msg in data_dicts])
return responses

def _flush_uid(self, uid: int) -> None:
self.request_queue.put_nowait(
RaggedRequest(
tid=None,
uid=uid,
input_tokens=None,
prompt_length=None,
seq_length=None,
max_length=None,
max_new_tokens=None,
last_in_prompt=None,
post_processing=None,
stream=None,
))


class MIIAsyncPipeline(RaggedBatchBase):
def __init__(self, *args, **kwargs):
@@ -823,17 +851,5 @@ def is_shutdown(self) -> bool:
def flush_uid(self, uid: int) -> None:
with self.lock:
if self.is_rank_0:
self.request_queue.put_nowait(
RaggedRequest(
tid=None,
uid=uid,
input_tokens=None,
prompt_length=None,
seq_length=None,
max_length=None,
max_new_tokens=None,
last_in_prompt=None,
post_processing=None,
stream=None,
))
self._queue_flush_request(uid)
self.uids.remove(uid)
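
The last three hunks consolidate the flush mechanism: the per-UID flush request (a RaggedRequest whose fields are all None except uid) used to be constructed in two places, _flush_uid and MIIAsyncPipeline.flush_uid, and is now built once in _queue_flush_request and reused by both call sites. A hedged sketch of how a consumer might recognize such a sentinel; is_flush_request and the drain_flushes loop are assumptions for illustration, not part of the diff:

from queue import Queue


def is_flush_request(req) -> bool:
    # Assumed convention from the sentinel above: a uid with no tokens means "flush".
    return req.uid is not None and req.input_tokens is None


def drain_flushes(request_queue: Queue, active_uids: set) -> None:
    # Illustrative consumer loop: drop cached state for every flush sentinel seen.
    while not request_queue.empty():
        req = request_queue.get_nowait()
        if is_flush_request(req):
            active_uids.discard(req.uid)


# Tiny usage example with a throwaway request object.
q: Queue = Queue()
q.put_nowait(type("Req", (), {"uid": 7, "input_tokens": None})())
uids = {7, 8}
drain_flushes(q, uids)
assert uids == {8}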
