Streaming for serving with chat's generate function (#1426)
rasbt authored Jun 4, 2024
1 parent fa88952 commit 0f3bca7
Showing 3 changed files with 86 additions and 16 deletions.
Empty file added litgpt/deploy/__init__.py
100 changes: 85 additions & 15 deletions litgpt/deploy/serve.py
@@ -12,7 +12,8 @@
 from litgpt.model import GPT
 from litgpt.config import Config
 from litgpt.tokenizer import Tokenizer
-from litgpt.generate.base import generate
+from litgpt.generate.base import generate as plain_generate
+from litgpt.chat.base import generate as stream_generate
 from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
 from litgpt.utils import (
     extend_checkpoint_dir,
@@ -28,7 +29,7 @@
     LitAPI, LitServer = object, object


-class SimpleLitAPI(LitAPI):
+class BaseLitAPI(LitAPI):
     def __init__(self,
                  checkpoint_dir: Path,
                  precision: Optional[str] = None,
@@ -86,12 +87,26 @@ def decode_request(self, request: Dict[str, Any]) -> Any:
         encoded = self.tokenizer.encode(prompt, device=self.device)
         return encoded

+
+class SimpleLitAPI(BaseLitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 top_p: float = 1.0,
+                 max_new_tokens: int = 50):
+        super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)
+
+    def setup(self, device: str):
+        super().setup(device)
+
     def predict(self, inputs: torch.Tensor) -> Any:
         # Run the model on the input and return the output.
         prompt_length = inputs.size(0)
         max_returned_tokens = prompt_length + self.max_new_tokens

-        y = generate(
+        y = plain_generate(
             self.model,
             inputs,
             max_returned_tokens,
@@ -111,6 +126,42 @@ def encode_response(self, output: torch.Tensor) -> Dict[str, Any]:
         return {"output": decoded_output}


+class StreamLitAPI(BaseLitAPI):
+    def __init__(self,
+                 checkpoint_dir: Path,
+                 precision: Optional[str] = None,
+                 temperature: float = 0.8,
+                 top_k: int = 50,
+                 top_p: float = 1.0,
+                 max_new_tokens: int = 50):
+        super().__init__(checkpoint_dir, precision, temperature, top_k, top_p, max_new_tokens)
+
+    def setup(self, device: str):
+        super().setup(device)
+
+    def predict(self, inputs: torch.Tensor) -> Any:
+        # Run the model on the input and return the output.
+        prompt_length = inputs.size(0)
+        max_returned_tokens = prompt_length + self.max_new_tokens
+
+        for block in self.model.transformer.h:
+            block.attn.kv_cache.reset_parameters()
+
+        yield from stream_generate(
+            self.model,
+            inputs,
+            max_returned_tokens,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            stop_tokens=([self.tokenizer.eos_id],)
+        )
+
+    def encode_response(self, output):
+        for out in output:
+            yield {"output": self.tokenizer.decode(out)}
+
+
 def run_server(
     checkpoint_dir: Path,
     precision: Optional[str] = None,
@@ -120,7 +171,8 @@ def run_server(
     max_new_tokens: int = 50,
     devices: int = 1,
     accelerator: str = "auto",
-    port: int = 8000
+    port: int = 8000,
+    stream: bool = False
 ) -> None:
     """Serve a LitGPT model using LitServe.
@@ -153,22 +205,40 @@ def run_server(
         accelerator: The type of accelerator to use. For example, "auto", "cuda", "cpu", or "mps".
             The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
         port: The network port number on which the model is configured to be served.
+        stream: Whether to stream the responses.
     """
     checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)
     pprint(locals())

     check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth")

-    server = LitServer(
-        SimpleLitAPI(
-            checkpoint_dir=checkpoint_dir,
-            precision=precision,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            max_new_tokens=max_new_tokens,
-        ),
-        accelerator=accelerator,
-        devices=devices)
+    if not stream:
+        server = LitServer(
+            SimpleLitAPI(
+                checkpoint_dir=checkpoint_dir,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                max_new_tokens=max_new_tokens,
+            ),
+            accelerator=accelerator,
+            devices=devices
+        )
+
+    else:
+        server = LitServer(
+            StreamLitAPI(
+                checkpoint_dir=checkpoint_dir,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                max_new_tokens=max_new_tokens,
+            ),
+            accelerator=accelerator,
+            devices=devices,
+            stream=True
+        )

     server.run(port=port)
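
For reference, a minimal client-side sketch for consuming the streaming mode added above. It is not part of the commit: the /predict route, the default port 8000, and the newline-delimited JSON chunk format are assumptions based on LitServe defaults and may differ in practice.

import json
import requests

# Hypothetical client for a server started with run_server(checkpoint_dir, stream=True).
# Assumes LitServe's default /predict route and newline-delimited JSON chunks.
response = requests.post(
    "http://127.0.0.1:8000/predict",
    json={"prompt": "What do Llamas eat?"},
    stream=True,
)
for line in response.iter_lines(decode_unicode=True):
    if line:  # skip keep-alive blank lines
        print(json.loads(line)["output"], end="", flush=True)

Under these assumptions, each dictionary yielded by StreamLitAPI.encode_response maps to one chunk on the client, so tokens can be printed as they are generated rather than after the full response is complete.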
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -38,7 +38,7 @@ all = [
     "tokenizers>=0.15.2", # pythia, falcon, redpajama
     "requests>=2.31.0", # litgpt.data
     "litdata==0.2.6", # litgpt.data
-    "litserve>=0.1.0", # litgpt.deploy
+    "litserve==0.1.1dev0", # litgpt.deploy
     "zstandard>=0.22.0", # litgpt.data.prepare_slimpajama.py
     "pandas>=1.9.0", # litgpt.data.prepare_starcoder.py
     "pyarrow>=15.0.2", # litgpt.data.prepare_starcoder.py
