From a48c67ef721179e8032b79a4439a8a07f6f0954b Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Fri, 23 Feb 2024 01:06:51 +0000
Subject: [PATCH] add vLLM (#21)

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 README.md                        | 26 +++++++++++-
 src/fastserve/client/__init__.py |  0
 src/fastserve/client/vllm.py     | 47 ++++++++++++++++++++
 src/fastserve/models/__init__.py |  1 +
 src/fastserve/models/vllm.py     | 73 ++++++++++++++++++++------------
 src/fastserve/utils.py           | 24 +++++++++++
 6 files changed, 143 insertions(+), 28 deletions(-)
 create mode 100644 src/fastserve/client/__init__.py
 create mode 100644 src/fastserve/client/vllm.py

diff --git a/README.md b/README.md
index a3c6247..1acc252 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,8 @@ python -m fastserve
 
 ## Usage/Examples
 
-### Serve Mistral-7B with Llama-cpp
+
+### Serve LLMs with Llama-cpp
 
 ```python
 from fastserve.models import ServeLlamaCpp
@@ -38,6 +39,29 @@ serve.run_server()
 ```
 or, run `python -m fastserve.models --model llama-cpp --model_path openhermes-2-mistral-7b.Q5_K_M.gguf` from terminal.
 
+
+### Serve vLLM
+
+```python
+from fastserve.models import ServeVLLM
+
+app = ServeVLLM("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+app.run_server()
+```
+
+You can use the FastServe client, which will automatically apply the chat template for you:
+
+```python
+from fastserve.client import vLLMClient
+from rich import print
+
+client = vLLMClient("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+response = client.chat("Write a python function to resize image to 224x224", keep_context=True)
+# print(client.context)
+print(response["outputs"][0]["text"])
+```
+
+
 ### Serve SDXL Turbo
 
 ```python
diff --git a/src/fastserve/client/__init__.py b/src/fastserve/client/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/fastserve/client/vllm.py b/src/fastserve/client/vllm.py
new file mode 100644
index 0000000..73dcc18
--- /dev/null
+++ b/src/fastserve/client/vllm.py
@@ -0,0 +1,47 @@
+import logging
+
+import requests
+
+
+class Client:
+    def __init__(self):
+        pass
+
+
+class vLLMClient(Client):
+    def __init__(self, model: str, base_url="http://localhost:8000/endpoint"):
+        from transformers import AutoTokenizer
+
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.context = []
+        self.base_url = base_url
+
+    def chat(self, prompt: str, keep_context=False):
+        new_msg = {"role": "user", "content": prompt}
+        if keep_context:
+            self.context.append(new_msg)
+            messages = self.context
+        else:
+            messages = [new_msg]
+
+        logging.info(messages)
+        chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
+        headers = {
+            "accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        data = {
+            "prompt": chat,
+            "temperature": 0.8,
+            "top_p": 1,
+            "max_tokens": 500,
+            "stop": [],
+        }
+
+        response = requests.post(self.base_url, headers=headers, json=data).json()
+        if keep_context:
+            self.context.append(
+                {"role": "assistant", "content": response["outputs"][0]["text"]}
+            )
+        return response
diff --git a/src/fastserve/models/__init__.py b/src/fastserve/models/__init__.py
index 3c7d565..8c495da 100644
--- a/src/fastserve/models/__init__.py
+++ b/src/fastserve/models/__init__.py
@@ -5,3 +5,4 @@
 from fastserve.models.llama_cpp import ServeLlamaCpp as ServeLlamaCpp
 from fastserve.models.sdxl_turbo import ServeSDXLTurbo as ServeSDXLTurbo
 from fastserve.models.ssd import ServeSSD1B as ServeSSD1B
+from fastserve.models.vllm import ServeVLLM as ServeVLLM
diff --git a/src/fastserve/models/vllm.py b/src/fastserve/models/vllm.py
index 3f57fe2..db89fe0 100644
--- a/src/fastserve/models/vllm.py
+++ b/src/fastserve/models/vllm.py
@@ -1,46 +1,65 @@
-import os
-from typing import List
+import logging
+from typing import Any, List, Optional
 
-from fastapi import FastAPI
 from pydantic import BaseModel
-from vllm import LLM, SamplingParams
 
-tensor_parallel_size = int(os.environ.get("DEVICES", "1"))
-print("tensor_parallel_size: ", tensor_parallel_size)
+from fastserve.core import FastServe
 
-llm = LLM("meta-llama/Llama-2-7b-hf", tensor_parallel_size=tensor_parallel_size)
+logger = logging.getLogger(__name__)
 
 
 class PromptRequest(BaseModel):
-    prompt: str
-    temperature: float = 1
+    prompt: str = "Write a python function to resize image to 224x224"
+    temperature: float = 0.8
+    top_p: float = 1.0
     max_tokens: int = 200
     stop: List[str] = []
 
 
 class ResponseModel(BaseModel):
     prompt: str
-    prompt_token_ids: List  # The token IDs of the prompt.
-    outputs: List[str]  # The output sequences of the request.
+    prompt_token_ids: Optional[List] = None  # The token IDs of the prompt.
+    text: str  # The output sequences of the request.
     finished: bool  # Whether the whole request is finished.
 
 
-app = FastAPI()
+class ServeVLLM(FastServe):
+    def __init__(
+        self,
+        model,
+        batch_size=1,
+        timeout=0.0,
+        *args,
+        **kwargs,
+    ):
+        from vllm import LLM
+
+        self.llm = LLM(model)
+        self.args = args
+        self.kwargs = kwargs
+        super().__init__(
+            batch_size,
+            timeout,
+            input_schema=PromptRequest,
+            # response_schema=ResponseModel,
+        )
+
+    def __call__(self, request: PromptRequest) -> Any:
+        from vllm import SamplingParams
 
+        sampling_params = SamplingParams(
+            temperature=request.temperature,
+            top_p=request.top_p,
+            max_tokens=request.max_tokens,
+        )
+        result = self.llm.generate(request.prompt, sampling_params=sampling_params)
+        logger.info(result)
+        return result
 
-@app.post("/serve", response_model=ResponseModel)
-def serve(request: PromptRequest):
-    sampling_params = SamplingParams(
-        max_tokens=request.max_tokens,
-        temperature=request.temperature,
-        stop=request.stop,
-    )
+    def handle(self, batch: List[PromptRequest]) -> List:
+        responses = []
+        for request in batch:
+            output = self(request)
+            responses.extend(output)
 
-    result = llm.generate(request.prompt, sampling_params=sampling_params)[0]
-    response = ResponseModel(
-        prompt=request.prompt,
-        prompt_token_ids=result.prompt_token_ids,
-        outputs=result.outputs,
-        finished=result.finished,
-    )
-    return response
+        return responses
diff --git a/src/fastserve/utils.py b/src/fastserve/utils.py
index 55deeb6..0dc2c59 100644
--- a/src/fastserve/utils.py
+++ b/src/fastserve/utils.py
@@ -23,3 +23,27 @@ def get_ui_folder():
     path = os.path.join(os.path.dirname(__file__), "../ui")
     path = os.path.abspath(path)
     return path
+
+
+def download_file(url: str, dest: str = None):
+    import requests
+    from tqdm import tqdm
+
+    if dest is None:
+        dest = os.path.abspath(os.path.basename(url))
+
+    response = requests.get(url, stream=True)
+    response.raise_for_status()
+    total_size = int(response.headers.get("content-length", 0))
+    block_size = 1024
+    with open(dest, "wb") as file, tqdm(
+        desc=dest,
+        total=total_size,
+        unit="iB",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for data in response.iter_content(block_size):
+            file.write(data)
+            bar.update(len(data))
+    return dest
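
A minimal sketch of calling the new endpoint directly with `requests`, mirroring the JSON payload that `vLLMClient.chat` builds in `src/fastserve/client/vllm.py`. The `http://localhost:8000/endpoint` URL and the `outputs[0]["text"]` response shape are assumptions taken from the client's defaults in this patch; the actual route depends on how `fastserve.core.FastServe` registers it, which is not shown here.

```python
# Sketch only: POST the same payload shape that vLLMClient sends.
# Assumes a server started with ServeVLLM("TinyLlama/TinyLlama-1.1B-Chat-v1.0").run_server()
# and that FastServe exposes the batched route at /endpoint (the vLLMClient default base_url).
import requests

payload = {
    "prompt": "Write a python function to resize image to 224x224",
    "temperature": 0.8,
    "top_p": 1.0,
    "max_tokens": 200,
    "stop": [],
}

resp = requests.post("http://localhost:8000/endpoint", json=payload, timeout=120)
resp.raise_for_status()
result = resp.json()
print(result)  # vLLMClient reads result["outputs"][0]["text"] from this response
```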