From f34b772429f77067736d98986f650eeb34263ca4 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 21 Nov 2023 12:52:45 -0800 Subject: [PATCH] Unify input/output types (#295) --- .github/workflows/formatting.yml | 4 + README.md | 20 +- mii/backend/client.py | 32 +- mii/batching/data_classes.py | 238 +++++++++++ mii/batching/ragged_batching.py | 309 ++------------- mii/grpc_related/modelresponse_server.py | 63 ++- mii/grpc_related/proto/modelresponse.proto | 80 +--- mii/grpc_related/proto/modelresponse_pb2.py | 53 +-- .../proto/modelresponse_pb2_grpc.py | 370 +----------------- mii/grpc_related/restful_gateway.py | 6 +- mii/grpc_related/task_methods.py | 105 ++--- tests/test_deployment.py | 38 +- 12 files changed, 442 insertions(+), 876 deletions(-) create mode 100644 mii/batching/data_classes.py diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index 6aa26187..82cd1b9f 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -27,6 +27,10 @@ jobs: which python python --version + - name: Install DeepSpeed + run: | + pip install git+https://github.com/microsoft/DeepSpeed.git + - name: Install MII run: | pip install .[dev] diff --git a/README.md b/README.md index b1e09bea..0d76590f 100644 --- a/README.md +++ b/README.md @@ -116,10 +116,17 @@ A non-persistent pipeline is a great way to try DeepSpeed-MII. Non-persistent pi ```python import mii pipe = mii.pipeline("mistralai/Mistral-7B-v0.1") -response = pipe("DeepSpeed is", max_new_tokens=128) +response = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=128) print(response) ``` +The returned `response` is a list of `Response` objects. We can access several details about the generation (e.g., `response[0].prompt_length`): + +- `generated_text: str` Text generated by the model. +- `prompt_length: int` Number of tokens in the original prompt. +- `generated_length: int` Number of tokens generated. +- `finish_reason: str` Reason for stopping generation. `stop` indicates the EOS token was generated and `length` indicates the generation reached `max_new_tokens` or `max_length`. + ### Tensor parallelism Taking advantage of multi-GPU systems for greater performance is easy with MII. When run with the `deepspeed` launcher, tensor parallelism is automatically controlled by the `--num_gpus` flag: @@ -158,10 +165,17 @@ A persistent deployment is ideal for use with long-running and production applic ```python import mii client = mii.serve("mistralai/Mistral-7B-v0.1") -response = client.generate("Deepspeed is", max_new_tokens=128) -print(response.response) +response = client.generate(["Deepspeed is", "Seattle is"], max_new_tokens=128) +print(response) ``` +The returned `response` is a list of `Response` objects. We can access several details about the generation (e.g., `response[0].prompt_length`): + +- `generated_text: str` Text generated by the model. +- `prompt_length: int` Number of tokens in the original prompt. +- `generated_length: int` Number of tokens generated. +- `finish_reason: str` Reason for stopping generation. `stop` indicates the EOS token was generated and `length` indicates the generation reached `max_new_tokens` or `max_length`. 
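For illustration, a minimal sketch of consuming these fields (assuming the `client` object created above; the prompts are placeholders):

```python
# Sketch: client.generate returns a list of Response objects, one per prompt.
responses = client.generate(["DeepSpeed is", "Seattle is"], max_new_tokens=128)
for r in responses:
    # Fields documented above: prompt_length, generated_length, finish_reason.
    print(f"{r.prompt_length} prompt tokens -> {r.generated_length} generated tokens "
          f"(finish_reason={r.finish_reason})")
    print(r.generated_text)
```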
+ If we want to generate text from other processes, we can do that too: ```python diff --git a/mii/backend/client.py b/mii/backend/client.py index b2cd8118..796324b3 100644 --- a/mii/backend/client.py +++ b/mii/backend/client.py @@ -7,6 +7,7 @@ import requests from typing import Dict, Any, Callable, List, Union +from mii.batching.data_classes import Response from mii.config import MIIConfig from mii.constants import GRPC_MAX_MSG_SIZE from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc @@ -37,18 +38,18 @@ def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None: channel = create_channel(host, self.port) self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> List[Response]: return self.generate(*args, **kwargs) - async def _request_async_response(self, request_dict, **query_kwargs): + async def _request_async_response(self, prompts, **query_kwargs): task_methods = TASK_METHODS_DICT[self.task] - proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs) + proto_request = task_methods.pack_request_to_proto(prompts, **query_kwargs) proto_response = await getattr(self.stub, task_methods.method)(proto_request) return task_methods.unpack_response_from_proto(proto_response) - async def _request_async_response_stream(self, request_dict, **query_kwargs): + async def _request_async_response_stream(self, prompts, **query_kwargs): task_methods = TASK_METHODS_DICT[self.task] - proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs) + proto_request = task_methods.pack_request_to_proto(prompts, **query_kwargs) assert hasattr(task_methods, "method_stream_out"), f"{self.task} does not support streaming response" async for response in getattr(self.stub, task_methods.method_stream_out)(proto_request): @@ -59,30 +60,29 @@ def generate(self, List[str]], streaming_fn: Callable = None, **query_kwargs: Dict[str, - Any]): + Any]) -> Union[None, + List[Response]]: if isinstance(prompts, str): prompts = [prompts] if streaming_fn is not None: if len(prompts) > 1: raise RuntimeError( "MII client streaming only supports a single prompt input.") - request_dict = {"query": prompts} - return self._generate_stream(streaming_fn, request_dict, **query_kwargs) + query_kwargs["stream"] = True + return self._generate_stream(streaming_fn, prompts, **query_kwargs) - request_dict = {"query": prompts} return self.asyncio_loop.run_until_complete( - self._request_async_response(request_dict, + self._request_async_response(prompts, **query_kwargs)) def _generate_stream(self, callback, - request_dict: Dict[str, - str], + prompts: List[str], **query_kwargs: Dict[str, - Any]): + Any]) -> None: async def put_result(): response_stream = self._request_async_response_stream( - request_dict, + prompts, **query_kwargs) while True: @@ -94,11 +94,11 @@ async def put_result(): self.asyncio_loop.run_until_complete(put_result()) - async def terminate_async(self): + async def terminate_async(self) -> None: await self.stub.Terminate( modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty()) - def terminate_server(self): + def terminate_server(self) -> None: self.asyncio_loop.run_until_complete(self.terminate_async()) if self.mii_config.enable_restful_api: requests.get( diff --git a/mii/batching/data_classes.py b/mii/batching/data_classes.py new file mode 100644 index 00000000..4bc46f73 --- /dev/null +++ b/mii/batching/data_classes.py @@ -0,0 +1,238 @@ +# Copyright (c) 
Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +from dataclasses import dataclass, field, asdict +from typing import Any, Dict, List, Iterator, Union +from typing_extensions import Self + +import torch + +from mii.constants import GenerationFinishReason + + +@dataclass +class Response: + generated_text: str + prompt_length: int + generated_length: int + finish_reason: GenerationFinishReason + + @staticmethod + def from_msg_dict(msg: Dict[str, Union[str, int]]) -> Self: + return Response(**msg) + + def to_msg_dict(self) -> Dict[str, Union[str, int]]: + return asdict(self) + + def __repr__(self) -> str: + return self.generated_text + + def __str__(self) -> str: + return self.generated_text + + +@dataclass +class RequestMsg: + uid: int + input_tokens: Union[torch.Tensor, List[int]] + + @property + def is_flush_request(self): + return self.input_tokens is None + + @staticmethod + def from_msg_dict(msg: Dict[str, Any]) -> Self: + input_tokens = msg["input_tokens"] + if input_tokens is not None: + input_tokens = torch.tensor(msg["input_tokens"], + dtype=torch.int32, + device=torch.device("cpu")) + return RequestMsg(uid=msg["uid"], input_tokens=input_tokens) + + +@dataclass +class Request: + tid: int + uid: int + input_tokens: torch.Tensor + prompt_tokens: torch.Tensor + seq_length: int + max_length: int + max_new_tokens: int + min_new_tokens: int + last_in_prompt: bool + post_processing: List[object] + stream: bool = False + ignore_eos: bool = False + return_full_text: bool = False + + _next_token: Union[None, torch.Tensor] = None + _is_done: bool = False + _generated_tokens: List[torch.Tensor] = field(default_factory=list) + _finish_reason: GenerationFinishReason = GenerationFinishReason.NONE + + @property + def prompt_length(self) -> int: + return len(self.prompt_tokens) + + @property + def next_token(self) -> Union[None, torch.Tensor]: + return self._next_token + + @next_token.setter + def next_token(self, next_token: Union[None, torch.Tensor]) -> None: + self._next_token = next_token + + @property + def is_done(self) -> bool: + if self.ignore_eos: + return False + if self.seq_length < self.min_new_tokens: + return False + return self._is_done + + @is_done.setter + def is_done(self, is_done: bool) -> None: + self._is_done = is_done + + @property + def generated_tokens(self) -> List[torch.Tensor]: + return self._generated_tokens + + @property + def finish_reason(self) -> GenerationFinishReason: + return self._finish_reason + + @property + def is_flush_request(self): + return self.input_tokens is None + + @property + def num_generated_tokens(self) -> int: + # We return zero while we are processing decomposed prompts + return self.seq_length - self.prompt_length + 1 if self.seq_length >= self.prompt_length else 0 + + @property + def stop_generation(self) -> bool: + # Returns whether to stop generation for request + if self.is_done: + self._finish_reason = GenerationFinishReason.STOP + return True + if (self.seq_length >= self.max_length) or (self.num_generated_tokens >= + self.max_new_tokens): + self._finish_reason = GenerationFinishReason.LENGTH + return True + return False + + def to_msg_dict(self) -> Dict[str, Any]: + # Returns a minimal version of the request of purposes of broadcasting to all ranks + input_tokens = self.input_tokens + if input_tokens is not None: + input_tokens = self.input_tokens.tolist() + return {"uid": self.uid, "input_tokens": input_tokens} + + def accumulate_generated_token(self) -> None: + # Append the latest token to the list of 
generated tokens + if not self.is_done: + self._generated_tokens.append(self.next_token) + + def clear_generated_token(self) -> None: + self._generated_tokens.clear() + + def set_next_as_input(self) -> None: + # Places the next token into the input token for next round of generation + if self.next_token is not None: + self.input_tokens = self.next_token.unsqueeze(0) + self.last_in_prompt = True + self.next_token = None + self.is_done = False + + +class RequestBatch: + def __init__(self, requests: List[Request] = None) -> None: + if requests is None: + requests = [] + self.requests = requests + + def __len__(self) -> int: + return len(self.requests) + + def __contains__(self, r: Request) -> bool: + return r in self.requests + + def __nonzero__(self) -> bool: + if len(self.requests) != 0: + return True + return False + + def __iter__(self) -> Iterator[Request]: + return iter(self.requests) + + def __repr__(self) -> str: + return f"RequestBatch({self.requests})" + + @property + def requests_to_run(self) -> Self: + return RequestBatch([r for r in self.requests if not r.is_flush_request]) + + @property + def requests_to_flush(self) -> Self: + return RequestBatch([r for r in self.requests if r.is_flush_request]) + + @property + def last_in_prompt(self) -> Self: + return RequestBatch([r for r in self.requests if r.last_in_prompt]) + + @property + def completed(self) -> Self: + return RequestBatch([r for r in self.requests if r.stop_generation]) + + @property + def uids(self) -> List[int]: + return [r.uid for r in self.requests] + + @property + def lengths(self) -> List[int]: + return [len(r.input_tokens) for r in self.requests] + + @property + def tokens(self) -> List[torch.Tensor]: + return [r.input_tokens for r in self.requests] + + @property + def next_tokens(self) -> List[torch.Tensor]: + return [r.next_token for r in self.requests] + + @property + def done_tokens(self) -> List[torch.Tensor]: + return [r.is_done for r in self.requests] + + @next_tokens.setter + def next_tokens(self, next_tokens: List[torch.Tensor]) -> None: + assert len(next_tokens) == len(self.requests) + for idx, r in enumerate(self.requests): + r.next_token = next_tokens[idx] + + @done_tokens.setter + def done_tokens(self, done_tokens: List[torch.Tensor]) -> None: + assert len(done_tokens) == len(self.requests) + for idx, r in enumerate(self.requests): + r.is_done = done_tokens[idx] + + def to_msg_dicts(self) -> List[Dict[str, Any]]: + return [r.to_msg_dict() for r in self.requests] + + @staticmethod + def from_msg_dicts(msg_dicts: List[Dict[str, Any]]) -> Self: + return RequestBatch([RequestMsg.from_msg_dict(msg) for msg in msg_dicts]) + + def prune(self, uids: List[int]) -> None: + self.requests = [r for r in self.requests if r.uid not in uids] + + def append(self, r: Request) -> None: + self.requests.append(r) + + def update_seq_length(self) -> None: + for r in self.requests: + r.seq_length += r.input_tokens.size(0) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index a74700b3..8a5fe01d 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -9,10 +9,8 @@ import threading import time from collections import deque, defaultdict -from dataclasses import dataclass, asdict, field from functools import cached_property -from typing import Dict, Tuple, List, Any, Iterator, Union, DefaultDict -from typing_extensions import Self +from typing import Dict, Tuple, List, Any, Union, DefaultDict import torch import ujson @@ -42,6 +40,7 @@ TEMP_NAME, SAMPLER_NAME, STOP_NAME) 
+from mii.batching.data_classes import Response, Request, RequestBatch from mii.batching.generation.logit_processors import TopPLogitProcessor, TopKLogitProcessor, TemperatureLogitProcessor from mii.batching.generation.samplers import LogitsSampler, GreedySampler from mii.batching.generation.stop_criterion import EosGenerationStopCriterion, TokenStopCriterion @@ -55,262 +54,6 @@ from mii.logging import logger -@dataclass -class Response: - generated_text: str - prompt_length: int - generated_length: int - finish_reason: GenerationFinishReason - - @staticmethod - def from_msg(msg: Dict[str, Union[str, int]]) -> Self: - return Response( - generated_text=msg["generated_text"], - prompt_length=msg["prompt_length"], - generated_length=msg["generated_length"], - finish_reason=GenerationFinishReason(msg["finish_reason"]), - ) - - def get_msg(self) -> Dict[str, Union[str, int]]: - return { - "generated_text": self.generated_text, - "prompt_length": self.prompt_length, - "generated_length": self.generated_length, - "finish_reason": self.finish_reason.value - } - - def __repr__(self) -> str: - return self.generated_text - - def __str__(self) -> str: - return self.generated_text - - -class ResponseBatch: - def __init__(self, responses: List[Response]) -> None: - self.responses = responses - - def __iter__(self) -> Iterator[Response]: - return iter(self.responses) - - def __repr__(self) -> str: - return "\n\n".join(str(r) for r in self.responses) - - @property - def generated_texts(self) -> List[str]: - return [r.generated_text for r in self.responses] - - @property - def prompt_lengths(self) -> List[int]: - return [r.prompt_length for r in self.responses] - - @property - def generated_lengths(self) -> List[int]: - return [r.generated_length for r in self.responses] - - @property - def finish_reasons(self) -> List[GenerationFinishReason]: - return [r.finish_reason for r in self.responses] - - def append(self, response: Response) -> None: - self.responses.append(response) - - -@dataclass -class RaggedRequestMsg: - uid: int - input_tokens: Union[torch.Tensor, List[int]] - - @property - def is_flush_request(self): - return self.input_tokens is None - - @staticmethod - def from_msg(msg: Dict[str, int]) -> Self: - return RaggedRequestMsg( - uid=msg["uid"], - input_tokens=None - if msg["input_tokens"] is None else torch.tensor(msg["input_tokens"], - dtype=torch.int32, - device=torch.device("cpu")), - ) - - -@dataclass -class RaggedRequest: - tid: int - uid: int - input_tokens: torch.Tensor - prompt_tokens: torch.Tensor - seq_length: int - max_length: int - max_new_tokens: int - min_new_tokens: int - last_in_prompt: bool - post_processing: List[object] - stream: bool = False - ignore_eos: bool = False - return_full_text: bool = False - - _next_token: Union[None, torch.Tensor] = None - _is_done: bool = False - _generated_tokens: List[torch.Tensor] = field(default_factory=list) - _finish_reason: GenerationFinishReason = GenerationFinishReason.NONE - - @property - def prompt_length(self) -> int: - return len(self.prompt_tokens) - - @property - def next_token(self) -> Union[None, torch.Tensor]: - return self._next_token - - @next_token.setter - def next_token(self, next_token: Union[None, torch.Tensor]) -> None: - self._next_token = next_token - - @property - def is_done(self) -> bool: - if self.ignore_eos: - return False - if self.seq_length < self.min_new_tokens: - return False - return self._is_done - - @is_done.setter - def is_done(self, is_done: bool) -> None: - self._is_done = is_done - - @property - def 
generated_tokens(self) -> List[torch.Tensor]: - return self._generated_tokens - - @property - def finish_reason(self) -> GenerationFinishReason: - return self._finish_reason - - @property - def is_flush_request(self): - return self.input_tokens is None - - @property - def num_generated_tokens(self) -> int: - # We return zero while we are processing decomposed prompts - return self.seq_length - self.prompt_length + 1 if self.seq_length >= self.prompt_length else 0 - - @property - def stop_generation(self) -> bool: - if self.is_done: - self._finish_reason = GenerationFinishReason.STOP - return True - if (self.seq_length >= self.max_length) or (self.num_generated_tokens >= - self.max_new_tokens): - self._finish_reason = GenerationFinishReason.LENGTH - return True - return False - - def get_msg(self) -> RaggedRequestMsg: - return RaggedRequestMsg( - uid=self.uid, - input_tokens=None - if self.input_tokens is None else self.input_tokens.tolist(), - ) - - def accumulate_generated_token(self) -> None: - if not self.is_done: - self._generated_tokens.append(self.next_token) - - def clear_generated_token(self) -> None: - self._generated_tokens.clear() - - def set_next_as_input(self) -> None: - if self.next_token is not None: - self.input_tokens = self.next_token.unsqueeze(0) - self.last_in_prompt = True - self.next_token = None - self.is_done = False - - -class RaggedRequestBatch: - def __init__(self, requests: List[RaggedRequest]) -> None: - self.requests = requests - - def __len__(self) -> int: - return len(self.requests) - - def __contains__(self, r: RaggedRequest) -> bool: - return r in self.requests - - def __nonzero__(self) -> bool: - if len(self.requests) != 0: - return True - return False - - def __iter__(self) -> Iterator[RaggedRequest]: - return iter(self.requests) - - def __repr__(self) -> str: - return f"RaggedRequestBatch({self.requests})" - - @property - def requests_to_run(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if not r.is_flush_request]) - - @property - def requests_to_flush(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if r.is_flush_request]) - - @property - def last_in_prompt(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if r.last_in_prompt]) - - @property - def completed(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if r.stop_generation]) - - @property - def uids(self) -> List[int]: - return [r.uid for r in self.requests] - - @property - def lengths(self) -> List[int]: - return [len(r.input_tokens) for r in self.requests] - - @property - def tokens(self) -> List[torch.Tensor]: - return [r.input_tokens for r in self.requests] - - @property - def next_tokens(self) -> List[torch.Tensor]: - return [r.next_token for r in self.requests] - - @property - def done_tokens(self) -> List[torch.Tensor]: - return [r.is_done for r in self.requests] - - @next_tokens.setter - def next_tokens(self, next_tokens: List[torch.Tensor]) -> None: - assert len(next_tokens) == len(self.requests) - for idx, r in enumerate(self.requests): - r.next_token = next_tokens[idx] - - @done_tokens.setter - def done_tokens(self, done_tokens: List[torch.Tensor]) -> None: - assert len(done_tokens) == len(self.requests) - for idx, r in enumerate(self.requests): - r.is_done = done_tokens[idx] - - def prune(self, uids: List[int]) -> None: - self.requests = [r for r in self.requests if r.uid not in uids] - - def append(self, r: RaggedRequest) -> None: - self.requests.append(r) - - def update_seq_length(self) -> None: - for r in 
self.requests: - r.seq_length += r.input_tokens.size(0) - - class RaggedBatchBase: def __init__(self, inference_engine, tokenizer, model_config): self.inference_engine = inference_engine @@ -327,7 +70,7 @@ def __init__(self, inference_engine, tokenizer, model_config): self.request_queue: queue.Queue = queue.Queue() self.result_queues: Dict[int, queue.Queue] = {} - self.scheduled_requests: RaggedRequestBatch = RaggedRequestBatch([]) + self.scheduled_requests: RequestBatch = RequestBatch() self.buffer = deque() self.scheduled_length = 0 self.scheduled_seq_num = 0 @@ -429,27 +172,26 @@ def _print_profiled_times(self) -> None: self._num_generated_tokens = 0 @sync_debug - def _bcast_requests(self, force=False) -> RaggedRequestBatch: + def _bcast_requests(self, force=False) -> RequestBatch: if self.is_rank_0: if not self.scheduled_requests and not force: return self.scheduled_requests # Rank 0 gets batch of requests and broadcasts to other ranks - data_dicts = [asdict(r.get_msg()) for r in self.scheduled_requests] + data_dicts = self.scheduled_requests.to_msg_dicts() json_data = ujson.dumps(data_dicts) self.socket.send_string(json_data) else: try: json_data = self.socket.recv_string() data_dicts = ujson.loads(json_data) - self.scheduled_requests = RaggedRequestBatch( - [RaggedRequestMsg.from_msg(msg) for msg in data_dicts]) + self.scheduled_requests = RequestBatch.from_msg_dicts(data_dicts) except zmq.Again: - self.scheduled_requests = RaggedRequestBatch([]) + self.scheduled_requests = RequestBatch() return self.scheduled_requests def _reset_scheduler_bookkeeping(self) -> None: - self.scheduled_requests = RaggedRequestBatch([]) + self.scheduled_requests = RequestBatch() self.scheduled_length = 0 self.scheduled_seq_num = 0 self.scheduled_req_blocks.zero_() @@ -458,8 +200,8 @@ def _reset_scheduler_bookkeeping(self) -> None: def _process_logits( self, next_token_logits: torch.Tensor, - running_requests: RaggedRequestBatch) -> Tuple[torch.Tensor, - torch.Tensor]: + running_requests: RequestBatch) -> Tuple[torch.Tensor, + torch.Tensor]: next_token_logits = next_token_logits[:, :self.vocab_size] next_token_logits = self.logit_processor(next_token_logits, running_requests, @@ -474,7 +216,7 @@ def _process_logits( return next_tokens, done_tokens @sync_debug - def _generate_output(self, r: RaggedRequest) -> bool: + def _generate_output(self, r: Request) -> bool: outputs = [] if r.stream: outputs.append(( @@ -503,7 +245,7 @@ def _generate_output(self, r: RaggedRequest) -> bool: for output in outputs: self.result_queues[r.tid].put_nowait(output) - def _do_schedule_requests(self, requests: List[RaggedRequest]) -> None: + def _do_schedule_requests(self, requests: List[Request]) -> None: free_blocks = self.inference_engine.free_blocks conf_manager = self.inference_engine._config.state_manager @@ -583,7 +325,7 @@ def schedule_requests(self) -> None: print( "Deadlock detected. Resetting KV cache and recomputing requests. Consider limiting number of concurrent requests or decreasing max lengths of prompts/generations." 
) - self.scheduled_requests = RaggedRequestBatch([]) + self.scheduled_requests = RequestBatch() self.reset_request_status() else: scheduled_requests_ids = set(id(r) for r in self.scheduled_requests) @@ -592,7 +334,7 @@ def schedule_requests(self) -> None: def _queue_flush_request(self, uid: int) -> None: self.request_queue.put_nowait( - RaggedRequest( + Request( tid=None, uid=uid, input_tokens=None, @@ -627,7 +369,7 @@ def make_request(self, tid: int, uid: int, input_tokens: torch.Tensor, - kwargs: Dict) -> RaggedRequest: + kwargs: Dict) -> Request: prompt_length = len(input_tokens) max_length = kwargs.pop(MAX_LENGTH_KWARG, self.max_length) assert max_length > prompt_length, f"prompt length must be less than {MAX_LENGTH_KWARG}" @@ -687,7 +429,7 @@ def make_request(self, assert kwargs == {}, f"Unknown keyword arguments {kwargs}" - return RaggedRequest( + return Request( tid=tid, uid=uid, input_tokens=input_tokens, @@ -726,10 +468,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.tid = threading.get_ident() - def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch: + def __call__(self, inputs: Union[str, List[str]], **kwargs) -> List[Response]: if isinstance(inputs, str): inputs = [inputs] - outputs: ResponseBatch = ResponseBatch([]) + outputs: List[Response] = [] uids_running: List[int] = list(range(len(inputs))) uids_complete_order: List[int] = [] @@ -757,12 +499,12 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch: while self.scheduled_requests: self.generate() - outputs = ResponseBatch([ + outputs = [ r for idx, r in sorted(zip(uids_complete_order, outputs), key=lambda pair: pair[0]) - ]) + ] if self.model_config.all_rank_output: outputs = self._bcast_responses(outputs) @@ -782,15 +524,15 @@ def _get_response(self) -> Tuple[int, Response]: response = self.make_response(generated_tokens, result[2], result[3], result[4]) return uid, response - def _bcast_responses(self, responses: ResponseBatch) -> ResponseBatch: + def _bcast_responses(self, responses: List[Response]) -> List[Response]: if self.is_rank_0: - data_dicts = [r.get_msg() for r in responses] + data_dicts = [r.to_msg_dict() for r in responses] json_data = ujson.dumps(data_dicts) self.socket.send_string(json_data) else: json_data = self.socket.recv_string() data_dicts = ujson.loads(json_data) - responses = ResponseBatch([Response.from_msg(msg) for msg in data_dicts]) + responses = [Response.from_msg_dict(msg) for msg in data_dicts] return responses @@ -851,11 +593,6 @@ def put_request(self, prompt: str, kwargs: Dict) -> int: return uid - def is_response_ready(self, uid: int) -> bool: - if not self.is_rank_0: - return True - return not self.result_queues[uid].empty() - def get_response(self) -> Tuple[int, Response]: # TODO: We should avoid any request/response work with non-rank 0, but # this requires some refactoring how we do the put and request in diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index 113a3ae2..69d37890 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -3,29 +3,26 @@ # DeepSpeed Team import asyncio +import queue +import sys +import threading from concurrent import futures -import logging +from typing import Dict import grpc - from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -from .proto import modelresponse_pb2_grpc -import sys -import threading -import time -import queue +from mii.backend.client import 
create_channel from mii.constants import ( + GenerationFinishReason, GRPC_MAX_MSG_SIZE, TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT, STREAM_RESPONSE_QUEUE_TIMEOUT, ) -from mii.grpc_related.task_methods import TASK_METHODS_DICT -from mii.backend.client import create_channel - -from mii.constants import GenerationFinishReason +from mii.grpc_related.proto import modelresponse_pb2_grpc +from mii.grpc_related.task_methods import TASK_METHODS_DICT, TaskMethods class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer): @@ -53,7 +50,7 @@ def __init__(self, async_pipeline=None): self.method_name_to_task = {m.method: t for t, m in TASK_METHODS_DICT.items()} self.lock = threading.Lock() - def _run_inference(self, method_name, request_proto): + def _get_task_methods(self, method_name: str) -> Dict[str, TaskMethods]: if method_name not in self.method_name_to_task: raise ValueError(f"unknown method: {method_name}") @@ -62,12 +59,14 @@ def _run_inference(self, method_name, request_proto): raise ValueError(f"unknown task: {task}") task_methods = TASK_METHODS_DICT[task] - prompts, kwargs = task_methods.unpack_request_from_proto(request_proto) + return task_methods + + def GeneratorReply(self, request, context): + task_methods = self._get_task_methods("GeneratorReply") + + prompts, kwargs = task_methods.unpack_request_from_proto(request) + uids_running, uids_complete_order, responses = [], [], [] - start = time.time() - uids_running = [] - uids_complete_order = [] - responses = [] # Put requests for all prompts into the pipeline for p in prompts: request_kwargs = kwargs.copy() @@ -85,7 +84,6 @@ def _run_inference(self, method_name, request_proto): self.inference_pipeline.flush_uid(uid) uids_complete_order.append(uids_running.index(uid)) uids_running.remove(uid) - end = time.time() # Sort responses in the order of prompts responses = [ @@ -95,31 +93,19 @@ def _run_inference(self, method_name, request_proto): key=lambda pair: pair[0]) ] - return task_methods.pack_response_to_proto(responses, end - start, -1) - - def GeneratorReply(self, request, context): - return self._run_inference("GeneratorReply", request) - - def _run_inference_stream(self, method_name, request_proto) -> int: - task = self.method_name_to_task[method_name] - task_methods = TASK_METHODS_DICT[task] - prompts, kwargs = task_methods.unpack_request_from_proto(request_proto) - - kwargs["stream"] = True - # NOTE: Streaming handle only single prompt inputs - return self.inference_pipeline.put_request(prompts[0], kwargs) + return task_methods.pack_response_to_proto(responses) def GeneratorReplyStream(self, request, context): - method_name = "GeneratorReply" - task = self.method_name_to_task[method_name] - task_methods = TASK_METHODS_DICT[task] + task_methods = self._get_task_methods("GeneratorReply") + + prompts, kwargs = task_methods.unpack_request_from_proto(request) + uid = self.inference_pipeline.put_request(prompts[0], kwargs) - uid = self._run_inference_stream(method_name, request) while True: response_uid, r = self.inference_pipeline.get_response() assert uid == response_uid, "uid mismatch" done = r.finish_reason != GenerationFinishReason.NONE - response = task_methods.pack_response_to_proto([r], 0.0, 0.0) + response = task_methods.pack_response_to_proto([r]) yield response if done: break @@ -258,8 +244,8 @@ def invoke_intercept_method_stream(request_proto, context): response_proto = result_queue.get( timeout=STREAM_RESPONSE_QUEUE_TIMEOUT) yield response_proto - if response_proto.details[0].finish_reason != str( - 
GenerationFinishReason.NONE): + if response_proto.response[0].finish_reason != str( + GenerationFinishReason.NONE.value): break except queue.Empty: print( @@ -302,5 +288,6 @@ def serve_load_balancing(model_config, lb_port): if __name__ == "__main__": + import logging logging.basicConfig() serve_inference(None, sys.argv[1]) diff --git a/mii/grpc_related/proto/modelresponse.proto b/mii/grpc_related/proto/modelresponse.proto index c2d0899f..9ea04a9c 100644 --- a/mii/grpc_related/proto/modelresponse.proto +++ b/mii/grpc_related/proto/modelresponse.proto @@ -25,17 +25,8 @@ package modelresponse; service ModelResponse { rpc Terminate (google.protobuf.Empty) returns (google.protobuf.Empty) {} - rpc CreateSession (SessionID) returns (google.protobuf.Empty) {} - rpc DestroySession (SessionID) returns (google.protobuf.Empty) {} - rpc GeneratorReply (MultiStringRequest) returns (MultiStringReply) {} - rpc ClassificationReply (SingleStringRequest) returns (SingleStringReply) {} - rpc QuestionAndAnswerReply(QARequest) returns (SingleStringReply) {} - rpc FillMaskReply(SingleStringRequest) returns (SingleStringReply) {} - rpc TokenClassificationReply(SingleStringRequest) returns (SingleStringReply) {} - rpc ConversationalReply(ConversationRequest) returns (ConversationReply) {} - rpc Txt2ImgReply(MultiStringRequest) returns (ImageReply) {} - - rpc GeneratorReplyStream (MultiStringRequest) returns (stream GenerationReply) {} + rpc GeneratorReply (MultiStringRequest) returns (MultiGenerationReply) {} + rpc GeneratorReplyStream (MultiStringRequest) returns (stream MultiGenerationReply) {} } message Dictionary { @@ -52,10 +43,6 @@ message Value { } } -message SessionID { - string session_id = 1; -} - message SingleStringRequest { string request = 1; map query_kwargs = 2; @@ -66,62 +53,15 @@ message MultiStringRequest { map query_kwargs = 2; } -message SingleStringReply { +message SingleGenerationReply { string response = 1; - float time_taken = 2; - float model_time_taken = 3; -} - -message MultiStringReply { - repeated string response = 1; - float time_taken = 2; - float model_time_taken = 3; -} - -message GenerationDetails { - string finish_reason = 1; - int64 prompt_tokens = 2; - int64 generated_tokens = 3; -} - -message GenerationReply { - repeated string response = 1; - // A request may contain multiple prompts and they produce different number of tokens. - // When streaming output is enabled, a response may contain generated tokens only for some prompts. - // `indices` represents the indices of prompts for which `response` and `details` are provided. 
- repeated int64 indices = 2; - repeated GenerationDetails details = 3; - float time_taken = 4; - float model_time_taken = 5; -} - -message QARequest { - string question = 1; - string context = 2; - map query_kwargs = 3; -} - -message ConversationRequest { - string text = 1; - string conversation_id = 2; - repeated string past_user_inputs = 3; - repeated string generated_responses = 4; - map query_kwargs = 5; -} - -message ConversationReply { - string conversation_id = 1; - repeated string past_user_inputs = 2; - repeated string generated_responses = 3; - float time_taken = 4; - float model_time_taken = 5; + string finish_reason = 2; + int64 prompt_tokens = 3; + int64 generated_tokens = 4; + float time_taken = 5; + float model_time_taken = 6; } -message ImageReply { - repeated bytes images = 1; - repeated bool nsfw_content_detected = 2; - string mode = 3; - int64 size_w = 4; - int64 size_h = 5; - float time_taken = 6; +message MultiGenerationReply { + repeated SingleGenerationReply response = 1; } diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index 6b5294f7..c152e207 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ b/mii/grpc_related/proto/modelresponse_pb2.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: modelresponse.proto """Generated protocol buffer code.""" @@ -17,7 +16,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"\x88\x01\n\nDictionary\x12\x35\n\x06values\x18\x01 \x03(\x0b\x32%.modelresponse.Dictionary.ValuesEntry\x1a\x43\n\x0bValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x8c\x01\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x12+\n\x06mvalue\x18\x05 \x01(\x0b\x32\x19.modelresponse.DictionaryH\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"[\n\x11GenerationDetails\x12\x15\n\rfinish_reason\x18\x01 \x01(\t\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x03\x12\x18\n\x10generated_tokens\x18\x03 \x01(\x03\"\x95\x01\n\x0fGenerationReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x0f\n\x07indices\x18\x02 \x03(\x03\x12\x31\n\x07\x64\x65tails\x18\x03 \x03(\x0b\x32 
.modelresponse.GenerationDetails\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x88\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xb3\x07\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x12]\n\x14GeneratorReplyStream\x12!.modelresponse.MultiStringRequest\x1a\x1e.modelresponse.GenerationReply\"\x00\x30\x01\x62\x06proto3' + b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"\x88\x01\n\nDictionary\x12\x35\n\x06values\x18\x01 \x03(\x0b\x32%.modelresponse.Dictionary.ValuesEntry\x1a\x43\n\x0bValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x8c\x01\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x12+\n\x06mvalue\x18\x05 \x01(\x0b\x32\x19.modelresponse.DictionaryH\x00\x42\x0e\n\x0coneof_values\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 
\x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x9f\x01\n\x15SingleGenerationReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x03\x12\x18\n\x10generated_tokens\x18\x04 \x01(\x03\x12\x12\n\ntime_taken\x18\x05 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x06 \x01(\x02\"N\n\x14MultiGenerationReply\x12\x36\n\x08response\x18\x01 \x03(\x0b\x32$.modelresponse.SingleGenerationReply2\x8e\x02\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12Z\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a#.modelresponse.MultiGenerationReply\"\x00\x12\x62\n\x14GeneratorReplyStream\x12!.modelresponse.MultiStringRequest\x1a#.modelresponse.MultiGenerationReply\"\x00\x30\x01\x62\x06proto3' ) _globals = globals() @@ -31,46 +30,24 @@ _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _QAREQUEST_QUERYKWARGSENTRY._options = None - _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' _globals['_DICTIONARY']._serialized_start = 68 _globals['_DICTIONARY']._serialized_end = 204 _globals['_DICTIONARY_VALUESENTRY']._serialized_start = 137 _globals['_DICTIONARY_VALUESENTRY']._serialized_end = 204 _globals['_VALUE']._serialized_start = 207 _globals['_VALUE']._serialized_end = 347 - _globals['_SESSIONID']._serialized_start = 349 - _globals['_SESSIONID']._serialized_end = 380 - _globals['_SINGLESTRINGREQUEST']._serialized_start = 383 - _globals['_SINGLESTRINGREQUEST']._serialized_end = 570 - _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_MULTISTRINGREQUEST']._serialized_start = 573 - _globals['_MULTISTRINGREQUEST']._serialized_end = 758 - _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_SINGLESTRINGREPLY']._serialized_start = 760 - _globals['_SINGLESTRINGREPLY']._serialized_end = 843 - _globals['_MULTISTRINGREPLY']._serialized_start = 845 - _globals['_MULTISTRINGREPLY']._serialized_end = 927 - _globals['_GENERATIONDETAILS']._serialized_start = 929 - _globals['_GENERATIONDETAILS']._serialized_end = 1020 - _globals['_GENERATIONREPLY']._serialized_start = 1023 - _globals['_GENERATIONREPLY']._serialized_end = 1172 - _globals['_QAREQUEST']._serialized_start = 1175 - _globals['_QAREQUEST']._serialized_end = 1360 - _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_CONVERSATIONREQUEST']._serialized_start = 1363 - _globals['_CONVERSATIONREQUEST']._serialized_end = 1627 - _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_CONVERSATIONREPLY']._serialized_start = 1630 - _globals['_CONVERSATIONREPLY']._serialized_end 
= 1775 - _globals['_IMAGEREPLY']._serialized_start = 1777 - _globals['_IMAGEREPLY']._serialized_end = 1902 - _globals['_MODELRESPONSE']._serialized_start = 1905 - _globals['_MODELRESPONSE']._serialized_end = 2852 + _globals['_SINGLESTRINGREQUEST']._serialized_start = 350 + _globals['_SINGLESTRINGREQUEST']._serialized_end = 537 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 465 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 537 + _globals['_MULTISTRINGREQUEST']._serialized_start = 540 + _globals['_MULTISTRINGREQUEST']._serialized_end = 725 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 465 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 537 + _globals['_SINGLEGENERATIONREPLY']._serialized_start = 728 + _globals['_SINGLEGENERATIONREPLY']._serialized_end = 887 + _globals['_MULTIGENERATIONREPLY']._serialized_start = 889 + _globals['_MULTIGENERATIONREPLY']._serialized_end = 967 + _globals['_MODELRESPONSE']._serialized_start = 970 + _globals['_MODELRESPONSE']._serialized_end = 1240 # @@protoc_insertion_point(module_scope) diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index 4f16a368..8da300b6 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -24,121 +24,31 @@ def __init__(self, channel): SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) - self.CreateSession = channel.unary_unary( - '/modelresponse.ModelResponse/CreateSession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) - self.DestroySession = channel.unary_unary( - '/modelresponse.ModelResponse/DestroySession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) self.GeneratorReply = channel.unary_unary( '/modelresponse.ModelResponse/GeneratorReply', request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.MultiStringReply.FromString, - ) - self.ClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/ClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) - self.QuestionAndAnswerReply = channel.unary_unary( - '/modelresponse.ModelResponse/QuestionAndAnswerReply', - request_serializer=modelresponse__pb2.QARequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) - self.FillMaskReply = channel.unary_unary( - '/modelresponse.ModelResponse/FillMaskReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) - self.TokenClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/TokenClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) - self.ConversationalReply = channel.unary_unary( - '/modelresponse.ModelResponse/ConversationalReply', - request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ConversationReply.FromString, - ) - 
self.Txt2ImgReply = channel.unary_unary( - '/modelresponse.ModelResponse/Txt2ImgReply', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ImageReply.FromString, + response_deserializer=modelresponse__pb2.MultiGenerationReply.FromString, ) self.GeneratorReplyStream = channel.unary_stream( '/modelresponse.ModelResponse/GeneratorReplyStream', request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.GenerationReply.FromString, + response_deserializer=modelresponse__pb2.MultiGenerationReply.FromString, ) class ModelResponseServicer(object): """Missing associated documentation comment in .proto file.""" - ERROR_MSG = 'Method not implemented!' - def Terminate(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) - - def CreateSession(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) - - def DestroySession(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GeneratorReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) - - def ClassificationReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) - - def QuestionAndAnswerReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) - - def FillMaskReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) - - def TokenClassificationReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) - - def ConversationalReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) - - def Txt2ImgReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GeneratorReplyStream(self, request, context): """Missing associated documentation comment in .proto file.""" 
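With the session and task-specific RPCs removed, the regenerated stub above exposes a single unary `GeneratorReply` call that returns a `MultiGenerationReply`. A minimal sketch of invoking it directly through the generated classes (the server address is an assumption for illustration, and `query_kwargs` is left empty for brevity; a real client would normally go through `mii.client` as shown in the README):

```python
# Sketch: call the unified GeneratorReply RPC via the regenerated stub.
# The server address below is an assumption, not part of this patch.
import grpc

from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc

channel = grpc.insecure_channel("localhost:50050")
stub = modelresponse_pb2_grpc.ModelResponseStub(channel)

request = modelresponse_pb2.MultiStringRequest(request=["DeepSpeed is"])
reply = stub.GeneratorReply(request)  # MultiGenerationReply

for r in reply.response:  # repeated SingleGenerationReply
    print(r.finish_reason, r.prompt_tokens, r.generated_tokens)
    print(r.response)
```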
@@ -156,67 +66,19 @@ def add_ModelResponseServicer_to_server(servicer, server): response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. SerializeToString, ), - 'CreateSession': - grpc.unary_unary_rpc_method_handler( - servicer.CreateSession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - ), - 'DestroySession': - grpc.unary_unary_rpc_method_handler( - servicer.DestroySession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - ), 'GeneratorReply': grpc.unary_unary_rpc_method_handler( servicer.GeneratorReply, request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, - ), - 'ClassificationReply': - grpc.unary_unary_rpc_method_handler( - servicer.ClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'QuestionAndAnswerReply': - grpc.unary_unary_rpc_method_handler( - servicer.QuestionAndAnswerReply, - request_deserializer=modelresponse__pb2.QARequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'FillMaskReply': - grpc.unary_unary_rpc_method_handler( - servicer.FillMaskReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'TokenClassificationReply': - grpc.unary_unary_rpc_method_handler( - servicer.TokenClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'ConversationalReply': - grpc.unary_unary_rpc_method_handler( - servicer.ConversationalReply, - request_deserializer=modelresponse__pb2.ConversationRequest.FromString, - response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, - ), - 'Txt2ImgReply': - grpc.unary_unary_rpc_method_handler( - servicer.Txt2ImgReply, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.ImageReply.SerializeToString, + response_serializer=modelresponse__pb2.MultiGenerationReply. + SerializeToString, ), 'GeneratorReplyStream': grpc.unary_stream_rpc_method_handler( servicer.GeneratorReplyStream, request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.GenerationReply.SerializeToString, + response_serializer=modelresponse__pb2.MultiGenerationReply. 
+ SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler('modelresponse.ModelResponse', @@ -253,58 +115,6 @@ def Terminate(request, timeout, metadata) - @staticmethod - def CreateSession(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/CreateSession', - modelresponse__pb2.SessionID.SerializeToString, - google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def DestroySession(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/DestroySession', - modelresponse__pb2.SessionID.SerializeToString, - google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - @staticmethod def GeneratorReply(request, target, @@ -321,163 +131,7 @@ def GeneratorReply(request, target, '/modelresponse.ModelResponse/GeneratorReply', modelresponse__pb2.MultiStringRequest.SerializeToString, - modelresponse__pb2.MultiStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def ClassificationReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/ClassificationReply', - modelresponse__pb2.SingleStringRequest.SerializeToString, - modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def QuestionAndAnswerReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/QuestionAndAnswerReply', - modelresponse__pb2.QARequest.SerializeToString, - modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def FillMaskReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/FillMaskReply', - modelresponse__pb2.SingleStringRequest.SerializeToString, - modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def TokenClassificationReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - 
-                                 wait_for_ready=None,
-                                 timeout=None,
-                                 metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/modelresponse.ModelResponse/TokenClassificationReply',
-            modelresponse__pb2.SingleStringRequest.SerializeToString,
-            modelresponse__pb2.SingleStringReply.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata)
-
-    @staticmethod
-    def ConversationalReply(request,
-                            target,
-                            options=(),
-                            channel_credentials=None,
-                            call_credentials=None,
-                            insecure=False,
-                            compression=None,
-                            wait_for_ready=None,
-                            timeout=None,
-                            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/modelresponse.ModelResponse/ConversationalReply',
-            modelresponse__pb2.ConversationRequest.SerializeToString,
-            modelresponse__pb2.ConversationReply.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata)
-
-    @staticmethod
-    def Txt2ImgReply(request,
-                     target,
-                     options=(),
-                     channel_credentials=None,
-                     call_credentials=None,
-                     insecure=False,
-                     compression=None,
-                     wait_for_ready=None,
-                     timeout=None,
-                     metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/modelresponse.ModelResponse/Txt2ImgReply',
-            modelresponse__pb2.MultiStringRequest.SerializeToString,
-            modelresponse__pb2.ImageReply.FromString,
+            modelresponse__pb2.MultiGenerationReply.FromString,
             options,
             channel_credentials,
             insecure,
@@ -503,7 +157,7 @@ def GeneratorReplyStream(request,
             target,
             '/modelresponse.ModelResponse/GeneratorReplyStream',
             modelresponse__pb2.MultiStringRequest.SerializeToString,
-            modelresponse__pb2.GenerationReply.FromString,
+            modelresponse__pb2.MultiGenerationReply.FromString,
             options,
             channel_credentials,
             insecure,
diff --git a/mii/grpc_related/restful_gateway.py b/mii/grpc_related/restful_gateway.py
index dc2dab71..5c2bc48a 100644
--- a/mii/grpc_related/restful_gateway.py
+++ b/mii/grpc_related/restful_gateway.py
@@ -2,11 +2,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # DeepSpeed Team
+import json
 import threading
 import time
+
 from flask import Flask, request
 from flask_restful import Resource, Api
-from google.protobuf.json_format import MessageToJson
 from werkzeug.serving import make_server
 
 import mii
@@ -29,7 +30,8 @@ def __init__(self):
         def post(self):
             data = request.get_json()
             result = client.generate(**data)
-            return MessageToJson(result)
+            result_json = json.dumps([r.to_msg_dict() for r in result])
+            return result_json
 
     app = Flask("RestfulGateway")
diff --git a/mii/grpc_related/task_methods.py b/mii/grpc_related/task_methods.py
index 7d37805a..77c4a3fc 100644
--- a/mii/grpc_related/task_methods.py
+++ b/mii/grpc_related/task_methods.py
@@ -4,7 +4,11 @@
 # DeepSpeed Team
 
 from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Tuple
 
+from google.protobuf.message import Message
+
+from mii.batching.data_classes import Response
 from mii.constants import TaskType
 from mii.grpc_related.proto import modelresponse_pb2
 from mii.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs
@@ -22,37 +26,27 @@ def single_string_response_to_proto(self, response, time_taken, model_time_taken
                                                model_time_taken=model_time_taken)
 
 
-def multi_string_request_to_proto(self, request_dict, **query_kwargs):
-    return modelresponse_pb2.MultiStringRequest(
-        request=request_dict["query"] if isinstance(request_dict["query"],
-                                                    list) else [request_dict["query"]],
-        query_kwargs=kwarg_dict_to_proto(query_kwargs),
-    )
-
-
-def proto_request_to_list(self, request):
-    prompts = [r for r in request.request]
-    kwargs = unpack_proto_query_kwargs(request.query_kwargs)
-    return prompts, kwargs
-
-
 class TaskMethods(ABC):
     @property
     @abstractmethod
     def method(self):
         ...
 
-    def pack_request_to_proto(self, request_dict, **query_kwargs):
-        return request_dict, query_kwargs
+    @abstractmethod
+    def pack_request_to_proto(self, request, **query_kwargs):
+        ...
 
-    def unpack_request_from_proto(self, request):
-        return request
+    @abstractmethod
+    def unpack_request_from_proto(self, proto_request):
+        ...
 
-    def pack_response_to_proto(self, response, time_taken, model_time_taken):
-        return response, time_taken, model_time_taken
+    @abstractmethod
+    def pack_response_to_proto(self, response):
+        ...
 
-    def unpack_response_from_proto(self, response):
-        return response
+    @abstractmethod
+    def unpack_response_from_proto(self, proto_response):
+        ...
 
 
 class TextGenerationMethods(TaskMethods):
@@ -64,31 +58,50 @@ def method(self):
     def method_stream_out(self):
         return "GeneratorReplyStream"
 
-    pack_request_to_proto = multi_string_request_to_proto
-    unpack_request_from_proto = proto_request_to_list
-
-    def pack_response_to_proto(self, responses, time_taken, model_time_taken):
-        text_responses = []
-        details = []
-
-        # Response a nested list of dicts
-        # [Sample, 1, Dict]
-        for response in responses:
-            text = response.generated_text
-            text_responses.append(text)
-            details.append(
-                modelresponse_pb2.GenerationDetails(
-                    finish_reason=str(response.finish_reason),
-                    prompt_tokens=response.prompt_length,
-                    generated_tokens=response.generated_length))
-
-        return modelresponse_pb2.GenerationReply(
-            response=text_responses,
-            indices=[0],
-            details=details,
-            time_taken=time_taken,
-            model_time_taken=model_time_taken,
+    def pack_request_to_proto(self,
+                              prompts: List[str],
+                              **query_kwargs: Dict[str,
+                                                   Any]) -> Message:
+        proto_request = modelresponse_pb2.MultiStringRequest(
+            request=prompts,
+            query_kwargs=kwarg_dict_to_proto(query_kwargs),
         )
+        return proto_request
+
+    def unpack_request_from_proto(self,
+                                  proto_request: Message) -> Tuple[List[str],
+                                                                   Dict[str,
+                                                                        Any]]:
+        prompts = [r for r in proto_request.request]
+        kwargs = unpack_proto_query_kwargs(proto_request.query_kwargs)
+        return prompts, kwargs
+
+    def pack_response_to_proto(self, responses: List[Response]) -> Message:
+        proto_responses = []
+        for r in responses:
+            proto_responses.append(
+                modelresponse_pb2.SingleGenerationReply(
+                    response=r.generated_text,
+                    finish_reason=str(r.finish_reason.value),
+                    prompt_tokens=r.prompt_length,
+                    generated_tokens=r.generated_length,
+                    time_taken=-1,
+                    model_time_taken=-1,
+                ))
+
+        return modelresponse_pb2.MultiGenerationReply(response=proto_responses, )
+
+    def unpack_response_from_proto(self, response: Message) -> List[Response]:
+        response_batch = []
+        for r in response.response:
+            response_batch.append(
+                Response(
+                    generated_text=r.response,
+                    prompt_length=r.prompt_tokens,
+                    generated_length=r.generated_tokens,
+                    finish_reason=r.finish_reason,
+                ))
+        return response_batch
 
 
 TASK_METHODS_DICT = {
diff --git a/tests/test_deployment.py b/tests/test_deployment.py
index 125b34c7..6e897d82 100644
--- a/tests/test_deployment.py
+++ b/tests/test_deployment.py
@@ -14,30 +14,30 @@
 
 
 def test_single_gpu(deployment, query):
-    output = deployment(query)
-    assert output, "output is empty"
+    outputs = deployment(query)
+    assert outputs[0], "output is empty"
 
 
 def test_streaming(deployment, query):
-    output = []
+    outputs = []
 
     def callback(response):
-        output.append(response.response)
+        outputs.append(response[0].generated_text)
 
     deployment(query, streaming_fn=callback)
-    assert output, "output is empty"
+    assert outputs, "output is empty"
 
 
 def test_multi_prompt(deployment, query):
-    output = deployment([query] * 4)
-    for r in output.response:
+    outputs = deployment([query] * 4)
+    for r in outputs:
         assert r, "output is empty"
 
 
 @pytest.mark.parametrize("tensor_parallel", [2])
 def test_multi_gpu(deployment, query):
-    output = deployment(query)
-    assert output, "output is empty"
+    outputs = deployment(query)
+    assert outputs[0], "output is empty"
 
 
 @pytest.mark.parametrize("replica_num", [2])
@@ -45,9 +45,9 @@ def test_multi_replica(deployment, query):
     deployment_name = deployment.mii_config.deployment_name
 
     start = time.time()
-    output = mii.client(deployment_name)(query, max_length=128, ignore_eos=True)
+    outputs = mii.client(deployment_name)(query, max_length=128, ignore_eos=True)
     end = time.time()
-    assert output, "output is empty"
+    assert outputs[0], "output is empty"
     single_query_time = end - start
 
     procs = []
@@ -77,7 +77,7 @@ def test_multi_replica(deployment, query):
 
 def test_query_kwargs(deployment, query):
     # test ignore_eos
-    output = deployment(
+    outputs = deployment(
         query,
         max_length=128,
         min_new_tokens=16,
@@ -86,14 +86,14 @@ def test_query_kwargs(deployment, query):
         top_k=50,
         temperature=0.9,
     )
-    assert output, "output is empty"
+    assert outputs[0], "output is empty"
 
 
 def test_do_sample(deployment, query):
     output_0 = deployment(query, do_sample=False, max_length=128)
     output_1 = deployment(query, do_sample=False, max_length=128)
     assert (
-        output_0.response == output_1.response
+        output_0[0] == output_1[0]
     ), "do_sample=False should always return the same output"
 
 
@@ -105,15 +105,15 @@ def test_stop_token(deployment, query):
 
 
 def test_return_full_text(deployment, query):
-    output = deployment(query, max_length=128, return_full_text=True)
-    assert output.response[0].startswith(query), "output should start with the prompt"
+    outputs = deployment(query, max_length=128, return_full_text=True)
+    assert outputs[0].generated_text.startswith(query), "output should start with the prompt"
 
 
 @pytest.mark.parametrize("enable_restful_api", [True])
 def test_restful_api(deployment, query, deployment_name, restful_api_port):
     # Verify deployment is running
-    output = deployment(query, max_length=128)
-    assert output, "output is empty"
+    outputs = deployment(query, max_length=128)
+    assert outputs[0], "output is empty"
 
     # Verify REST API
     url = f"http://localhost:{restful_api_port}/mii/{deployment_name}"
@@ -123,4 +123,4 @@ def test_restful_api(deployment, query, deployment_name, restful_api_port):
                            data=json_params,
                            headers={"Content-Type": "application/json"})
     assert result.status_code == 200
-    assert "response" in result.json()
+    assert "generated_text" in result.json()
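As a usage sketch, not part of the patch itself: after this change the RESTful gateway serializes the list of `Response` objects with `to_msg_dict()`, so a caller reads `generated_text` from each returned entry instead of the old protobuf `response` field. The port, deployment name, and prompts below are hypothetical placeholders, and the JSON body keys are assumed to mirror `client.generate()` keyword arguments.

```python
# Minimal sketch of querying the RESTful gateway and reading the unified output.
# Endpoint, deployment name, and prompts are placeholders, not values from the patch.
import json

import requests

url = "http://localhost:28080/mii/mistral-deployment"
payload = {"prompts": ["DeepSpeed is", "Seattle is"], "max_length": 128}

result = requests.post(url,
                       data=json.dumps(payload),
                       headers={"Content-Type": "application/json"})
result.raise_for_status()

# The gateway returns the Response list as JSON; handle the case where the
# framework delivers it as a JSON-encoded string.
body = result.json()
responses = json.loads(body) if isinstance(body, str) else body
for r in responses:
    print(r["generated_text"])
```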