Skip to content

Commit

Permalink
✨ Prompt adapter support
Browse files Browse the repository at this point in the history
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
  • Loading branch information
prashantgupta24 authored and dtrifiro committed Jul 10, 2024
1 parent 4471c23 commit 0e13eac
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 9 deletions.
12 changes: 11 additions & 1 deletion src/vllm_tgis_adapter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,18 @@ async def run_http_server(
args.lora_modules,
args.chat_template,
)

kwargs = {}
# prompt adapter arg required for vllm >0.5.1
if hasattr(args, "prompt_adapters"):
kwargs = {"prompt_adapters": args.prompt_adapters}

openai_serving_completion = OpenAIServingCompletion(
engine, model_config, served_model_names, args.lora_modules
engine,
model_config,
served_model_names,
args.lora_modules,
**kwargs,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine, model_config, served_model_names
Expand Down
27 changes: 22 additions & 5 deletions src/vllm_tgis_adapter/grpc/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
SingleGenerationRequest,
)
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest

from .validation import TGISValidationError

Expand All @@ -33,6 +34,7 @@ class AdapterMetadata:
unique_id: int # Unique integer for vllm to identify the adapter
adapter_type: str # The string name of the peft adapter type, e.g. LORA
full_path: str
full_config: dict # The loaded adapter_config.json dict


@dataclasses.dataclass
Expand All @@ -45,7 +47,7 @@ class AdapterStore:
async def validate_adapters(
request: SingleGenerationRequest | BatchedGenerationRequest,
adapter_store: AdapterStore | None,
) -> dict[str, LoRARequest]:
) -> dict[str, LoRARequest | PromptAdapterRequest]:
"""Validate the adapters.
Takes the adapter name from the request and constructs a valid
Expand All @@ -56,6 +58,9 @@ async def validate_adapters(
"""
global global_thread_pool # noqa: PLW0603
adapter_id = request.adapter_id
# Backwards compatibility for `prefix_id` arg
if not adapter_id and request.prefix_id:
adapter_id = request.prefix_id

if adapter_id and not adapter_store:
TGISValidationError.AdaptersDisabled.error()
Expand All @@ -73,18 +78,20 @@ async def validate_adapters(
if global_thread_pool is None:
global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)

adapter_type = await loop.run_in_executor(
adapter_config = await loop.run_in_executor(
global_thread_pool,
_get_adapter_type_from_file,
_load_adapter_config_from_file,
adapter_id,
local_adapter_path,
)
adapter_type = adapter_config.get("peft_type", None)

# Add to cache
adapter_metadata = AdapterMetadata(
unique_id=adapter_store.next_unique_id,
adapter_type=adapter_type,
full_path=local_adapter_path,
full_config=adapter_config,
)
adapter_store.adapters[adapter_id] = adapter_metadata

Expand All @@ -96,12 +103,22 @@ async def validate_adapters(
lora_local_path=adapter_metadata.full_path,
)
return {"lora_request": lora_request}
if adapter_metadata.adapter_type == "PROMPT_TUNING":
prompt_adapter_request = PromptAdapterRequest(
prompt_adapter_id=adapter_metadata.unique_id,
prompt_adapter_name=adapter_id,
prompt_adapter_local_path=adapter_metadata.full_path,
prompt_adapter_num_virtual_tokens=adapter_metadata.full_config.get(
"num_virtual_tokens", 0
),
)
return {"prompt_adapter_request": prompt_adapter_request}

# All other types unsupported
TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type) # noqa: RET503


def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> dict:
"""Get adapter from file.
Performs all the filesystem access required to deduce the type
Expand All @@ -123,7 +140,7 @@ def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
with open(adapter_config_path) as adapter_config_file:
adapter_config = json.load(adapter_config_file)

return adapter_config.get("peft_type", None)
return adapter_config


def _reject_bad_adapter_id(adapter_id: str) -> None:
Expand Down
9 changes: 6 additions & 3 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vllm.engine.async_llm_engine import _AsyncLLMEngine
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.inputs import TextTokensPrompt
from vllm.prompt_adapter.request import PromptAdapterRequest # noqa: TCH002

from vllm_tgis_adapter.logging import init_logger
from vllm_tgis_adapter.tgis_utils import logs
Expand Down Expand Up @@ -177,9 +178,11 @@ def __init__(
self.skip_special_tokens = not args.output_special_tokens
self.default_include_stop_seqs = args.default_include_stop_seqs

# Backwards compatibility for TGIS: PREFIX_STORE_PATH
adapter_cache_path = args.adapter_cache or args.prefix_store_path
self.adapter_store = (
AdapterStore(cache_path=args.adapter_cache, adapters={})
if args.adapter_cache
AdapterStore(cache_path=adapter_cache_path, adapters={})
if adapter_cache_path
else None
)
self.health_servicer = health_servicer
Expand Down Expand Up @@ -585,7 +588,7 @@ async def _validate_adapters(
self,
request: SingleGenerationRequest | BatchedGenerationRequest,
context: ServicerContext,
) -> dict[str, LoRARequest]:
) -> dict[str, LoRARequest | PromptAdapterRequest]:
try:
adapters = await validate_adapters(
request=request, adapter_store=self.adapter_store
Expand Down
4 changes: 4 additions & 0 deletions src/vllm_tgis_adapter/tgis_utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument("--tls-client-ca-cert-path", type=str)
# add a path when peft adapters will be loaded from
parser.add_argument("--adapter-cache", type=str)
# backwards-compatibility support for tgis prompt tuning
parser.add_argument(
"--prefix-store-path", type=str, help="Deprecated, use --adapter-cache"
)

# TODO check/add other args here

Expand Down

0 comments on commit 0e13eac

Please sign in to comment.