add vLLM (#21)
* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
aniketmaurya and pre-commit-ci[bot] authored Feb 23, 2024
1 parent 5e6f5f3 commit a48c67e
Showing 6 changed files with 143 additions and 28 deletions.
26 changes: 25 additions & 1 deletion README.md
@@ -26,7 +26,8 @@ python -m fastserve

## Usage/Examples

-### Serve Mistral-7B with Llama-cpp
+### Serve LLMs with Llama-cpp

```python
from fastserve.models import ServeLlamaCpp
@@ -38,6 +39,29 @@ serve.run_server()

Or, run `python -m fastserve.models --model llama-cpp --model_path openhermes-2-mistral-7b.Q5_K_M.gguf` from the terminal.


### Serve vLLM

```python
from fastserve.models import ServeVLLM

app = ServeVLLM("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
app.run_server()
```

You can use the FastServe client, which will automatically apply the chat template for you:

```python
from fastserve.client import vLLMClient
from rich import print

client = vLLMClient("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
response = client.chat("Write a python function to resize image to 224x224", keep_context=True)
# print(client.context)
print(response["outputs"][0]["text"])
```


### Serve SDXL Turbo

```python
Empty file.
47 changes: 47 additions & 0 deletions src/fastserve/client/vllm.py
@@ -0,0 +1,47 @@
import logging

import requests


class Client:
def __init__(self):
pass


class vLLMClient(Client):
def __init__(self, model: str, base_url="http://localhost:8000/endpoint"):
from transformers import AutoTokenizer

super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.context = []
self.base_url = base_url

def chat(self, prompt: str, keep_context=False):
new_msg = {"role": "user", "content": prompt}
if keep_context:
self.context.append(new_msg)
messages = self.context
else:
messages = [new_msg]

logging.info(messages)
chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
headers = {
"accept": "application/json",
"Content-Type": "application/json",
}
data = {
"prompt": chat,
"temperature": 0.8,
"top_p": 1,
"max_tokens": 500,
"stop": [],
}

response = requests.post(self.base_url, headers=headers, json=data).json()
if keep_context:
self.context.append(
{"role": "assistant", "content": response["outputs"][0]["text"]}
)
return response
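
For context, the client above relies on `transformers` to expand chat messages into the model's prompt format before sending the request. Below is a minimal sketch of that step (illustrative, not part of this commit), assuming the TinyLlama tokenizer and its bundled chat template are available locally:

```python
# Illustrative sketch: reproduce the prompt that vLLMClient builds before POSTing.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
messages = [
    {"role": "user", "content": "Write a python function to resize image to 224x224"}
]
# Same call the client uses; returns the templated prompt as a plain string.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print(prompt)
```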
1 change: 1 addition & 0 deletions src/fastserve/models/__init__.py
@@ -5,3 +5,4 @@
from fastserve.models.llama_cpp import ServeLlamaCpp as ServeLlamaCpp
from fastserve.models.sdxl_turbo import ServeSDXLTurbo as ServeSDXLTurbo
from fastserve.models.ssd import ServeSSD1B as ServeSSD1B
from fastserve.models.vllm import ServeVLLM as ServeVLLM
73 changes: 46 additions & 27 deletions src/fastserve/models/vllm.py
@@ -1,46 +1,65 @@
-import os
-from typing import List
+import logging
+from typing import Any, List, Optional

-from fastapi import FastAPI
from pydantic import BaseModel
-from vllm import LLM, SamplingParams

-tensor_parallel_size = int(os.environ.get("DEVICES", "1"))
-print("tensor_parallel_size: ", tensor_parallel_size)
+from fastserve.core import FastServe

-llm = LLM("meta-llama/Llama-2-7b-hf", tensor_parallel_size=tensor_parallel_size)
+logger = logging.getLogger(__name__)


class PromptRequest(BaseModel):
-    prompt: str
-    temperature: float = 1
+    prompt: str = "Write a python function to resize image to 224x224"
+    temperature: float = 0.8
    top_p: float = 1.0
    max_tokens: int = 200
    stop: List[str] = []


class ResponseModel(BaseModel):
    prompt: str
-    prompt_token_ids: List  # The token IDs of the prompt.
-    outputs: List[str]  # The output sequences of the request.
+    prompt_token_ids: Optional[List] = None  # The token IDs of the prompt.
+    text: str  # The output sequences of the request.
    finished: bool  # Whether the whole request is finished.


-app = FastAPI()
+class ServeVLLM(FastServe):
+    def __init__(
+        self,
+        model,
+        batch_size=1,
+        timeout=0.0,
+        *args,
+        **kwargs,
+    ):
+        from vllm import LLM
+
+        self.llm = LLM(model)
+        self.args = args
+        self.kwargs = kwargs
+        super().__init__(
+            batch_size,
+            timeout,
+            input_schema=PromptRequest,
+            # response_schema=ResponseModel,
+        )
+
+    def __call__(self, request: PromptRequest) -> Any:
+        from vllm import SamplingParams
+
+        sampling_params = SamplingParams(
+            temperature=request.temperature,
+            top_p=request.top_p,
+            max_tokens=request.max_tokens,
+        )
+        result = self.llm.generate(request.prompt, sampling_params=sampling_params)
+        logger.info(result)
+        return result

-@app.post("/serve", response_model=ResponseModel)
-def serve(request: PromptRequest):
-    sampling_params = SamplingParams(
-        max_tokens=request.max_tokens,
-        temperature=request.temperature,
-        stop=request.stop,
-    )
+    def handle(self, batch: List[PromptRequest]) -> List:
+        responses = []
+        for request in batch:
+            output = self(request)
+            responses.extend(output)

-    result = llm.generate(request.prompt, sampling_params=sampling_params)[0]
-    response = ResponseModel(
-        prompt=request.prompt,
-        prompt_token_ids=result.prompt_token_ids,
-        outputs=result.outputs,
-        finished=result.finished,
-    )
-    return response
+        return responses
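
The request body accepted by `ServeVLLM` mirrors the `PromptRequest` schema above. Below is a minimal sketch of querying a locally running server directly with `requests`, assuming it listens on port 8000 and exposes the same `/endpoint` route that `vLLMClient` targets by default:

```python
# Minimal sketch: POST a PromptRequest-shaped payload to a local ServeVLLM server.
# Assumes the server runs on localhost:8000 and serves the /endpoint route.
import requests

payload = {
    "prompt": "Write a python function to resize image to 224x224",
    "temperature": 0.8,
    "top_p": 1.0,
    "max_tokens": 200,
    "stop": [],
}
response = requests.post("http://localhost:8000/endpoint", json=payload)
print(response.json())
```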
24 changes: 24 additions & 0 deletions src/fastserve/utils.py
@@ -23,3 +23,27 @@ def get_ui_folder():
path = os.path.join(os.path.dirname(__file__), "../ui")
path = os.path.abspath(path)
return path


def download_file(url: str, dest: str = None):
import requests
from tqdm import tqdm

    # Derive the destination filename from the URL when none is given.
    if dest is None:
        dest = os.path.abspath(os.path.basename(url))

response = requests.get(url, stream=True)
response.raise_for_status()
total_size = int(response.headers.get("content-length", 0))
block_size = 1024
with open(dest, "wb") as file, tqdm(
desc=dest,
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in response.iter_content(block_size):
file.write(data)
bar.update(len(data))
return dest
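
A minimal usage sketch for `download_file`; the URL below is a placeholder, not a real artifact:

```python
# Hypothetical example: stream a weights file to disk with a progress bar.
from fastserve.utils import download_file

path = download_file(
    "https://example.com/openhermes-2-mistral-7b.Q5_K_M.gguf",  # placeholder URL
    dest="openhermes-2-mistral-7b.Q5_K_M.gguf",
)
print(path)
```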
