Recompute when the deadlock is detected #278

Merged: 6 commits, Nov 10, 2023
Changes from 3 commits
42 changes: 34 additions & 8 deletions mii/batching/ragged_batching.py
@@ -133,7 +133,7 @@ def from_msg(msg: Dict[str, int]) -> Self:
 class RaggedRequest:
     uid: int
     input_tokens: torch.Tensor
-    prompt_length: int
+    prompt_tokens: torch.Tensor
     seq_length: int
     max_length: int
     max_new_tokens: int
@@ -147,6 +147,10 @@ class RaggedRequest:
     _generated_tokens: List[torch.Tensor] = field(default_factory=list)
     _finish_reason: GenerationFinishReason = GenerationFinishReason.NONE
 
+    @property
+    def prompt_length(self) -> int:
+        return len(self.prompt_tokens)
+
     @property
     def next_token(self) -> Union[None, torch.Tensor]:
         return self._next_token

Review comment (Contributor), on def prompt_length:
In the case where we have a deadlock, the prompt_length will be incorrect in the output. Perhaps we need to store original_prompt_length or something like that. Don't worry about this for now. I have plans to expand the output details and can handle this in a future PR.

Reply (Contributor, PR author):
I see, you are right. I would appreciate it if you could fix it in your PR.
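A rough sketch of the idea discussed above (the RaggedRequestSketch name and original_prompt_length field are illustrative, not part of this PR): the prompt length could be captured once when the request is created, so that a deadlock-triggered recompute, which folds generated tokens back into prompt_tokens, does not change the value reported to the caller.

# Hypothetical sketch only; the names below are illustrative, not from this PR.
from dataclasses import dataclass, field

import torch


@dataclass
class RaggedRequestSketch:
    uid: int
    prompt_tokens: torch.Tensor
    # Recorded once at creation time and never rewritten afterwards.
    original_prompt_length: int = field(init=False)

    def __post_init__(self) -> None:
        self.original_prompt_length = len(self.prompt_tokens)

    @property
    def prompt_length(self) -> int:
        # Tracks the current prompt_tokens, which grows after a recompute.
        return len(self.prompt_tokens)

Output details could then report original_prompt_length instead of prompt_length after a recompute.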
@@ -204,6 +208,9 @@ def accumulate_generated_token(self) -> None:
         if not self.is_done:
             self._generated_tokens.append(self.next_token)
 
+    def clear_generated_token(self) -> None:
+        self._generated_tokens.clear()
+
     def set_next_as_input(self) -> None:
         if self.next_token is not None:
             self.input_tokens = self.next_token.unsqueeze(0)
@@ -551,11 +558,30 @@ def schedule_requests(self) -> None:
             self._do_schedule_requests(prompt_reqs)
 
         if len(self.buffer) > 0 and len(self.scheduled_requests) == 0:
-            raise RuntimeError("Deadlock detected: No requests were scheduled.")
+            print(
+                "Deadlock detected. Resetting KV cache and recomputing requests. Consider limiting number of concurrent requests or decreasing max lengths of prompts/generations."
+            )
+            self.scheduled_requests = RaggedRequestBatch([])
+            self.reset_request_status()
+        else:
+            scheduled_requests_ids = set(id(r) for r in self.scheduled_requests)
+            self.buffer = deque(
+                [r for r in self.buffer if id(r) not in scheduled_requests_ids])
+
+    def reset_request_status(self):
+        self.flush([r.uid for r in self.buffer if r.seq_length > 0])
+
+        new_buffer = deque()
+        for r in self.buffer:
+            new_req = copy.copy(r)
+            new_req.prompt_tokens = new_req.input_tokens = torch.concat(
+                [r.prompt_tokens] + [t.unsqueeze(0) for t in r.generated_tokens])
+            new_req.seq_length = 0
+            new_req.max_new_tokens = r.max_new_tokens - len(r.generated_tokens)
+            new_req.clear_generated_token()
+            new_buffer.append(new_req)
 
-        scheduled_requests_ids = set(id(r) for r in self.scheduled_requests)
-        self.buffer = deque(
-            [r for r in self.buffer if id(r) not in scheduled_requests_ids])
+        self.buffer = new_buffer
 
     def make_request(self,
                      uid: int,
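To make the effect of reset_request_status concrete, here is a small self-contained sketch of the same bookkeeping on a toy request (DummyRequest, reset_for_recompute, and the numbers are illustrative, not code from this PR): a request that has already generated some tokens is rebuilt so that its new prompt is the old prompt plus everything generated so far, its sequence position restarts at 0, and its remaining generation budget shrinks by the number of tokens already produced.

# Illustrative sketch of the recompute bookkeeping; not code from this PR.
import copy
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class DummyRequest:
    prompt_tokens: torch.Tensor
    input_tokens: torch.Tensor
    seq_length: int
    max_new_tokens: int
    generated_tokens: List[torch.Tensor] = field(default_factory=list)


def reset_for_recompute(r: DummyRequest) -> DummyRequest:
    new_req = copy.copy(r)
    # Fold already-generated tokens back into the prompt so they get recomputed.
    new_req.prompt_tokens = new_req.input_tokens = torch.concat(
        [r.prompt_tokens] + [t.unsqueeze(0) for t in r.generated_tokens])
    new_req.seq_length = 0                      # start over in the KV cache
    new_req.max_new_tokens = r.max_new_tokens - len(r.generated_tokens)
    new_req.generated_tokens = []               # the PR calls clear_generated_token()
    return new_req


req = DummyRequest(prompt_tokens=torch.tensor([1, 2, 3]),
                   input_tokens=torch.tensor([8]),
                   seq_length=4,
                   max_new_tokens=10,
                   generated_tokens=[torch.tensor(7), torch.tensor(8)])
new_req = reset_for_recompute(req)
assert len(new_req.prompt_tokens) == 5   # 3 prompt + 2 generated tokens
assert new_req.max_new_tokens == 8       # 10 - 2 already generated
assert new_req.seq_length == 0

This also shows why the reviewer's note above applies: after the reset, prompt_length reflects the grown prompt_tokens rather than the original user prompt.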
@@ -609,7 +635,7 @@ def make_request(self,
             RaggedRequest(
                 uid=uid,
                 input_tokens=input_tokens,
-                prompt_length=prompt_length,
+                prompt_tokens=input_tokens,
                 seq_length=0,
                 max_length=max_length,
                 max_new_tokens=max_new_tokens,
@@ -660,7 +686,7 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch:
             RaggedRequest(
                 uid=uid,
                 input_tokens=None,
-                prompt_length=None,
+                prompt_tokens=None,
                 seq_length=None,
                 max_length=None,
                 max_new_tokens=None,
@@ -823,7 +849,7 @@ def destroy_session(self,
             RaggedRequest(
                 uid=uid,
                 input_tokens=None,
-                prompt_length=None,
+                prompt_tokens=None,
                 seq_length=None,
                 max_length=None,
                 max_new_tokens=None,