
Commit

Unify input/output types (#295)
mrwyattii authored Nov 21, 2023
1 parent 2d49bc5 commit f34b772
Showing 12 changed files with 442 additions and 876 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/formatting.yml
@@ -27,6 +27,10 @@ jobs:
which python
python --version
- name: Install DeepSpeed
run: |
pip install git+https://github.com/microsoft/DeepSpeed.git
- name: Install MII
run: |
pip install .[dev]
20 changes: 17 additions & 3 deletions README.md
@@ -116,10 +116,17 @@ A non-persistent pipeline is a great way to try DeepSpeed-MII. Non-persistent pi
```python
import mii
pipe = mii.pipeline("mistralai/Mistral-7B-v0.1")
response = pipe("DeepSpeed is", max_new_tokens=128)
response = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=128)
print(response)
```

The returned `response` is a list of `Response` objects. We can access several details about the generation (e.g., `response[0].prompt_length`; see the example after this list):

- `generated_text: str` Text generated by the model.
- `prompt_length: int` Number of tokens in the original prompt.
- `generated_length: int` Number of tokens generated.
- `finish_reason: str` Reason for stopping generation. `stop` indicates the EOS token was generated and `length` indicates the generation reached `max_new_tokens` or `max_length`.
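
For example, we can loop over the returned list and read each field (a minimal sketch, assuming the `pipe` call above):

```python
# `response` is the list of Response objects returned by the pipeline call above.
for r in response:
    print(r.generated_text)    # the text produced by the model
    print(r.prompt_length)     # number of tokens in the original prompt
    print(r.generated_length)  # number of tokens generated
    print(r.finish_reason)     # why generation stopped: "stop" or "length" per the list above
```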

### Tensor parallelism

Taking advantage of multi-GPU systems for greater performance is easy with MII. When run with the `deepspeed` launcher, tensor parallelism is automatically controlled by the `--num_gpus` flag:
@@ -158,10 +165,17 @@ A persistent deployment is ideal for use with long-running and production applic
```python
import mii
client = mii.serve("mistralai/Mistral-7B-v0.1")
response = client.generate("Deepspeed is", max_new_tokens=128)
print(response.response)
response = client.generate(["Deepspeed is", "Seattle is"], max_new_tokens=128)
print(response)
```

The returned `response` is a list of `Response` objects. We can access several details about the generation (e.g., `response[0].prompt_length`; see the example after this list):

- `generated_text: str` Text generated by the model.
- `prompt_length: int` Number of tokens in the original prompt.
- `generated_length: int` Number of tokens generated.
- `finish_reason: str` Reason for stopping generation. `stop` indicates the EOS token was generated and `length` indicates the generation reached `max_new_tokens` or `max_length`.
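
As another minimal sketch (assuming the `response` returned by `client.generate` above), `finish_reason` lets us check whether a generation hit the token limit:

```python
# `response` is the list of Response objects returned by client.generate() above.
for r in response:
    if r.finish_reason == "length":  # comparing with the plain string is an assumption here
        print(f"Hit the token limit after {r.generated_length} tokens.")
    else:
        print(f"Stopped at the EOS token after {r.generated_length} tokens.")
    print(r.generated_text)
```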

If we want to generate text from other processes, we can do that too:

```python
32 changes: 16 additions & 16 deletions mii/backend/client.py
@@ -7,6 +7,7 @@
import requests
from typing import Dict, Any, Callable, List, Union

from mii.batching.data_classes import Response
from mii.config import MIIConfig
from mii.constants import GRPC_MAX_MSG_SIZE
from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
@@ -37,18 +38,18 @@ def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None:
channel = create_channel(host, self.port)
self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)

def __call__(self, *args, **kwargs):
def __call__(self, *args, **kwargs) -> List[Response]:
return self.generate(*args, **kwargs)

async def _request_async_response(self, request_dict, **query_kwargs):
async def _request_async_response(self, prompts, **query_kwargs):
task_methods = TASK_METHODS_DICT[self.task]
proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs)
proto_request = task_methods.pack_request_to_proto(prompts, **query_kwargs)
proto_response = await getattr(self.stub, task_methods.method)(proto_request)
return task_methods.unpack_response_from_proto(proto_response)

async def _request_async_response_stream(self, request_dict, **query_kwargs):
async def _request_async_response_stream(self, prompts, **query_kwargs):
task_methods = TASK_METHODS_DICT[self.task]
proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs)
proto_request = task_methods.pack_request_to_proto(prompts, **query_kwargs)
assert hasattr(task_methods, "method_stream_out"), f"{self.task} does not support streaming response"
async for response in getattr(self.stub,
task_methods.method_stream_out)(proto_request):
@@ -59,30 +60,29 @@ def generate(self,
List[str]],
streaming_fn: Callable = None,
**query_kwargs: Dict[str,
Any]):
Any]) -> Union[None,
List[Response]]:
if isinstance(prompts, str):
prompts = [prompts]
if streaming_fn is not None:
if len(prompts) > 1:
raise RuntimeError(
"MII client streaming only supports a single prompt input.")
request_dict = {"query": prompts}
return self._generate_stream(streaming_fn, request_dict, **query_kwargs)
query_kwargs["stream"] = True
return self._generate_stream(streaming_fn, prompts, **query_kwargs)

request_dict = {"query": prompts}
return self.asyncio_loop.run_until_complete(
self._request_async_response(request_dict,
self._request_async_response(prompts,
**query_kwargs))

def _generate_stream(self,
callback,
request_dict: Dict[str,
str],
prompts: List[str],
**query_kwargs: Dict[str,
Any]):
Any]) -> None:
async def put_result():
response_stream = self._request_async_response_stream(
request_dict,
prompts,
**query_kwargs)

while True:
@@ -94,11 +94,11 @@ async def put_result():

self.asyncio_loop.run_until_complete(put_result())

async def terminate_async(self):
async def terminate_async(self) -> None:
await self.stub.Terminate(
modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())

def terminate_server(self):
def terminate_server(self) -> None:
self.asyncio_loop.run_until_complete(self.terminate_async())
if self.mii_config.enable_restful_api:
requests.get(
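
For illustration (not part of this commit), a minimal sketch of how the updated `MIIClient.generate` signature above might be used, including the single-prompt streaming path; the exact type passed to `streaming_fn` is an assumption:

```python
import mii

client = mii.serve("mistralai/Mistral-7B-v0.1")

# Batched, non-streaming call: prompts may be a str or a List[str].
responses = client.generate(["DeepSpeed is", "Seattle is"], max_new_tokens=64)
for r in responses:
    print(r.generated_text)

# Streaming call: only a single prompt is allowed (the client raises otherwise).
def on_chunk(chunk):
    # Assumption: each streamed chunk unpacks to a list of Response objects.
    print(chunk[0].generated_text, end="", flush=True)

client.generate("DeepSpeed is", streaming_fn=on_chunk, max_new_tokens=64)

client.terminate_server()
```
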
238 changes: 238 additions & 0 deletions mii/batching/data_classes.py
@@ -0,0 +1,238 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Iterator, Union
from typing_extensions import Self

import torch

from mii.constants import GenerationFinishReason


@dataclass
class Response:
generated_text: str
prompt_length: int
generated_length: int
finish_reason: GenerationFinishReason

@staticmethod
def from_msg_dict(msg: Dict[str, Union[str, int]]) -> Self:
return Response(**msg)

def to_msg_dict(self) -> Dict[str, Union[str, int]]:
return asdict(self)

def __repr__(self) -> str:
return self.generated_text

def __str__(self) -> str:
return self.generated_text


@dataclass
class RequestMsg:
uid: int
input_tokens: Union[torch.Tensor, List[int]]

@property
def is_flush_request(self):
return self.input_tokens is None

@staticmethod
def from_msg_dict(msg: Dict[str, Any]) -> Self:
input_tokens = msg["input_tokens"]
if input_tokens is not None:
input_tokens = torch.tensor(msg["input_tokens"],
dtype=torch.int32,
device=torch.device("cpu"))
return RequestMsg(uid=msg["uid"], input_tokens=input_tokens)


@dataclass
class Request:
tid: int
uid: int
input_tokens: torch.Tensor
prompt_tokens: torch.Tensor
seq_length: int
max_length: int
max_new_tokens: int
min_new_tokens: int
last_in_prompt: bool
post_processing: List[object]
stream: bool = False
ignore_eos: bool = False
return_full_text: bool = False

_next_token: Union[None, torch.Tensor] = None
_is_done: bool = False
_generated_tokens: List[torch.Tensor] = field(default_factory=list)
_finish_reason: GenerationFinishReason = GenerationFinishReason.NONE

@property
def prompt_length(self) -> int:
return len(self.prompt_tokens)

@property
def next_token(self) -> Union[None, torch.Tensor]:
return self._next_token

@next_token.setter
def next_token(self, next_token: Union[None, torch.Tensor]) -> None:
self._next_token = next_token

@property
def is_done(self) -> bool:
if self.ignore_eos:
return False
if self.seq_length < self.min_new_tokens:
return False
return self._is_done

@is_done.setter
def is_done(self, is_done: bool) -> None:
self._is_done = is_done

@property
def generated_tokens(self) -> List[torch.Tensor]:
return self._generated_tokens

@property
def finish_reason(self) -> GenerationFinishReason:
return self._finish_reason

@property
def is_flush_request(self):
return self.input_tokens is None

@property
def num_generated_tokens(self) -> int:
# We return zero while we are processing decomposed prompts
return self.seq_length - self.prompt_length + 1 if self.seq_length >= self.prompt_length else 0

@property
def stop_generation(self) -> bool:
# Returns whether to stop generation for request
if self.is_done:
self._finish_reason = GenerationFinishReason.STOP
return True
if (self.seq_length >= self.max_length) or (self.num_generated_tokens >=
self.max_new_tokens):
self._finish_reason = GenerationFinishReason.LENGTH
return True
return False

def to_msg_dict(self) -> Dict[str, Any]:
# Returns a minimal version of the request for purposes of broadcasting to all ranks
input_tokens = self.input_tokens
if input_tokens is not None:
input_tokens = self.input_tokens.tolist()
return {"uid": self.uid, "input_tokens": input_tokens}

def accumulate_generated_token(self) -> None:
# Append the latest token to the list of generated tokens
if not self.is_done:
self._generated_tokens.append(self.next_token)

def clear_generated_token(self) -> None:
self._generated_tokens.clear()

def set_next_as_input(self) -> None:
# Places the next token into the input token for next round of generation
if self.next_token is not None:
self.input_tokens = self.next_token.unsqueeze(0)
self.last_in_prompt = True
self.next_token = None
self.is_done = False


class RequestBatch:
def __init__(self, requests: List[Request] = None) -> None:
if requests is None:
requests = []
self.requests = requests

def __len__(self) -> int:
return len(self.requests)

def __contains__(self, r: Request) -> bool:
return r in self.requests

def __bool__(self) -> bool:
if len(self.requests) != 0:
return True
return False

def __iter__(self) -> Iterator[Request]:
return iter(self.requests)

def __repr__(self) -> str:
return f"RequestBatch({self.requests})"

@property
def requests_to_run(self) -> Self:
return RequestBatch([r for r in self.requests if not r.is_flush_request])

@property
def requests_to_flush(self) -> Self:
return RequestBatch([r for r in self.requests if r.is_flush_request])

@property
def last_in_prompt(self) -> Self:
return RequestBatch([r for r in self.requests if r.last_in_prompt])

@property
def completed(self) -> Self:
return RequestBatch([r for r in self.requests if r.stop_generation])

@property
def uids(self) -> List[int]:
return [r.uid for r in self.requests]

@property
def lengths(self) -> List[int]:
return [len(r.input_tokens) for r in self.requests]

@property
def tokens(self) -> List[torch.Tensor]:
return [r.input_tokens for r in self.requests]

@property
def next_tokens(self) -> List[torch.Tensor]:
return [r.next_token for r in self.requests]

@property
def done_tokens(self) -> List[torch.Tensor]:
return [r.is_done for r in self.requests]

@next_tokens.setter
def next_tokens(self, next_tokens: List[torch.Tensor]) -> None:
assert len(next_tokens) == len(self.requests)
for idx, r in enumerate(self.requests):
r.next_token = next_tokens[idx]

@done_tokens.setter
def done_tokens(self, done_tokens: List[torch.Tensor]) -> None:
assert len(done_tokens) == len(self.requests)
for idx, r in enumerate(self.requests):
r.is_done = done_tokens[idx]

def to_msg_dicts(self) -> List[Dict[str, Any]]:
return [r.to_msg_dict() for r in self.requests]

@staticmethod
def from_msg_dicts(msg_dicts: List[Dict[str, Any]]) -> Self:
return RequestBatch([RequestMsg.from_msg_dict(msg) for msg in msg_dicts])

def prune(self, uids: List[int]) -> None:
self.requests = [r for r in self.requests if r.uid not in uids]

def append(self, r: Request) -> None:
self.requests.append(r)

def update_seq_length(self) -> None:
for r in self.requests:
r.seq_length += r.input_tokens.size(0)
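
To make the intended round trip concrete, a small sketch (not part of this commit) that exercises the data classes above, using only the fields and helpers shown in this file:

```python
from mii.batching.data_classes import RequestBatch, Response
from mii.constants import GenerationFinishReason

# A Response survives the to/from message-dict round trip used for RPC.
resp = Response(generated_text="DeepSpeed is a library",
                prompt_length=3,
                generated_length=4,
                finish_reason=GenerationFinishReason.LENGTH)
assert Response.from_msg_dict(resp.to_msg_dict()) == resp

# Requests broadcast to other ranks travel as minimal dicts; a batch of
# RequestMsg objects can be rebuilt from them. input_tokens=None marks a flush.
batch = RequestBatch.from_msg_dicts([
    {"uid": 0, "input_tokens": [1, 2, 3]},
    {"uid": 1, "input_tokens": None},
])
print(batch.uids)                    # [0, 1]
print(len(batch.requests_to_flush))  # 1
```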