Skip to content

Commit

Permalink
✨ Prompt adapter support
Browse files Browse the repository at this point in the history
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
  • Loading branch information
prashantgupta24 authored and dtrifiro committed Jul 11, 2024
1 parent 4387823 commit d152993
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 15 deletions.
12 changes: 11 additions & 1 deletion src/vllm_tgis_adapter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,18 @@ async def run_http_server(
args.lora_modules,
args.chat_template,
)

kwargs = {}
# prompt adapter arg required for vllm >0.5.1
if hasattr(args, "prompt_adapters"):
kwargs = {"prompt_adapters": args.prompt_adapters}

openai_serving_completion = OpenAIServingCompletion(
engine, model_config, served_model_names, args.lora_modules
engine,
model_config,
served_model_names,
args.lora_modules,
**kwargs,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine, model_config, served_model_names
Expand Down
34 changes: 26 additions & 8 deletions src/vllm_tgis_adapter/grpc/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
from pathlib import Path
from typing import TYPE_CHECKING

from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest

from .validation import TGISValidationError

if TYPE_CHECKING:
from vllm.entrypoints.grpc.pb.generation_pb2 import (
BatchedGenerationRequest,
SingleGenerationRequest,
)
from vllm.lora.request import LoRARequest

from .validation import TGISValidationError

global_thread_pool = None # used for loading adapter files from disk

Expand All @@ -33,6 +35,7 @@ class AdapterMetadata:
unique_id: int # Unique integer for vllm to identify the adapter
adapter_type: str # The string name of the peft adapter type, e.g. LORA
full_path: str
full_config: dict # The loaded adapter_config.json dict


@dataclasses.dataclass
Expand All @@ -45,7 +48,7 @@ class AdapterStore:
async def validate_adapters(
request: SingleGenerationRequest | BatchedGenerationRequest,
adapter_store: AdapterStore | None,
) -> dict[str, LoRARequest]:
) -> dict[str, LoRARequest | PromptAdapterRequest]:
"""Validate the adapters.
Takes the adapter name from the request and constructs a valid
Expand All @@ -56,6 +59,9 @@ async def validate_adapters(
"""
global global_thread_pool # noqa: PLW0603
adapter_id = request.adapter_id
# Backwards compatibility for `prefix_id` arg
if not adapter_id and request.prefix_id:
adapter_id = request.prefix_id

if adapter_id and not adapter_store:
TGISValidationError.AdaptersDisabled.error()
Expand All @@ -73,18 +79,20 @@ async def validate_adapters(
if global_thread_pool is None:
global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)

adapter_type = await loop.run_in_executor(
adapter_config = await loop.run_in_executor(
global_thread_pool,
_get_adapter_type_from_file,
_load_adapter_config_from_file,
adapter_id,
local_adapter_path,
)
adapter_type = adapter_config.get("peft_type", None)

# Add to cache
adapter_metadata = AdapterMetadata(
unique_id=adapter_store.next_unique_id,
adapter_type=adapter_type,
full_path=local_adapter_path,
full_config=adapter_config,
)
adapter_store.adapters[adapter_id] = adapter_metadata

Expand All @@ -96,12 +104,22 @@ async def validate_adapters(
lora_local_path=adapter_metadata.full_path,
)
return {"lora_request": lora_request}
if adapter_metadata.adapter_type == "PROMPT_TUNING":
prompt_adapter_request = PromptAdapterRequest(
prompt_adapter_id=adapter_metadata.unique_id,
prompt_adapter_name=adapter_id,
prompt_adapter_local_path=adapter_metadata.full_path,
prompt_adapter_num_virtual_tokens=adapter_metadata.full_config.get(
"num_virtual_tokens", 0
),
)
return {"prompt_adapter_request": prompt_adapter_request}

# All other types unsupported
TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type) # noqa: RET503


def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> dict:
"""Get adapter from file.
Performs all the filesystem access required to deduce the type
Expand All @@ -123,7 +141,7 @@ def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
with open(adapter_config_path) as adapter_config_file:
adapter_config = json.load(adapter_config_file)

return adapter_config.get("peft_type", None)
return adapter_config


def _reject_bad_adapter_id(adapter_id: str) -> None:
Expand Down
34 changes: 28 additions & 6 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
TGISStatLogger,
)

from .adapters import AdapterStore, validate_adapters
from .pb import generation_pb2_grpc
from .pb.generation_pb2 import DESCRIPTOR as _GENERATION_DESCRIPTOR
from .pb.generation_pb2 import (
Expand All @@ -56,6 +55,14 @@
)
from .validation import validate_input, validate_params

try:
from .adapters import AdapterStore, validate_adapters
except ImportError:
adapters_available = False
else:
adapters_available = True


if TYPE_CHECKING:
import argparse
from collections.abc import AsyncIterator, MutableSequence
Expand All @@ -76,6 +83,11 @@
SingleGenerationRequest,
)

try:
from .adapters import PromptAdapterRequest
except ImportError:
pass

_T = TypeVar("_T")
_F = TypeVar("_F", Callable, Coroutine)

Expand Down Expand Up @@ -170,9 +182,11 @@ def __init__(
self.skip_special_tokens = not args.output_special_tokens
self.default_include_stop_seqs = args.default_include_stop_seqs

# Backwards compatibility for TGIS: PREFIX_STORE_PATH
adapter_cache_path = args.adapter_cache or args.prefix_store_path
self.adapter_store = (
AdapterStore(cache_path=args.adapter_cache, adapters={})
if args.adapter_cache
AdapterStore(cache_path=adapter_cache_path, adapters={})
if adapter_cache_path
else None
)
self.health_servicer = health_servicer
Expand Down Expand Up @@ -213,7 +227,11 @@ async def Generate(
generators = []
max_is_token_limit = [False] * request_count

adapter_kwargs = await self._validate_adapters(request, context)
adapter_kwargs = (
await self._validate_adapters(request, context)
if adapters_available
else {}
)

for i, req in enumerate(request.requests):
input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize(
Expand Down Expand Up @@ -309,7 +327,11 @@ async def GenerateStream(
sampling_params, truncate_input_tokens, request.request.text, context
)

adapter_kwargs = await self._validate_adapters(request, context)
adapter_kwargs = (
await self._validate_adapters(request, context)
if adapters_available
else {}
)
inputs = TextTokensPrompt(
prompt=request.request.text, prompt_token_ids=input_ids
)
Expand Down Expand Up @@ -577,7 +599,7 @@ async def _validate_adapters(
self,
request: SingleGenerationRequest | BatchedGenerationRequest,
context: ServicerContext,
) -> dict[str, LoRARequest]:
) -> dict[str, LoRARequest | PromptAdapterRequest]:
try:
adapters = await validate_adapters(
request=request, adapter_store=self.adapter_store
Expand Down
4 changes: 4 additions & 0 deletions src/vllm_tgis_adapter/tgis_utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument("--tls-client-ca-cert-path", type=str)
# add a path when peft adapters will be loaded from
parser.add_argument("--adapter-cache", type=str)
# backwards-compatibility support for tgis prompt tuning
parser.add_argument(
"--prefix-store-path", type=str, help="Deprecated, use --adapter-cache"
)

# TODO check/add other args here

Expand Down

0 comments on commit d152993

Please sign in to comment.