chore: use the LoRA tokenizer in OpenAI API #599

Merged: 1 commit, Aug 24, 2024
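
In outline, this change makes the chat endpoint resolve its tokenizer per request instead of always using the base model's tokenizer: when a request names a LoRA adapter, the adapter's own tokenizer is fetched from the engine and used for chat-template rendering, prompt tokenization, and logprob decoding. A minimal sketch of the pattern, using the calls that appear in the diff below (`_maybe_get_adapter`, `engine.get_tokenizer`, `apply_chat_template`); the wrapper function, its arguments, and the `serving`/`request` placeholders are illustrative, not part of the diff:

```python
from typing import List, Tuple

from transformers import PreTrainedTokenizer


async def render_chat_prompt(serving, request,
                             conversation: List[dict]
                             ) -> Tuple[PreTrainedTokenizer, str]:
    """Illustrative only: resolve the per-request tokenizer, then render the prompt.

    `serving` stands in for an OpenAIServingChat instance and `request` for a
    ChatCompletionRequest; neither wrapper exists in the diff itself.
    """
    # If the request names a served LoRA adapter, use that adapter's tokenizer;
    # otherwise lora_request is None and the base model tokenizer is returned.
    _, lora_request = serving._maybe_get_adapter(request)
    tokenizer = await serving.engine.get_tokenizer(lora_request)

    # A per-request chat_template wins; otherwise fall back to the template
    # loaded at startup (None means the tokenizer's own default template).
    prompt = tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=False,
        add_generation_prompt=request.add_generation_prompt,
        chat_template=request.chat_template or serving.chat_template,
    )
    return tokenizer, prompt
```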
aphrodite/endpoints/openai/api_server.py (2 additions, 1 deletion)
@@ -528,7 +528,8 @@ def run_server(args, llm_engine=None):
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names)
openai_serving_tokenization = OpenAIServingTokenization(
engine, model_config, served_model_names, args.chat_template)
engine, model_config, served_model_names, args.lora_modules,
args.chat_template)
app.root_path = args.root_path

tokenizer = get_tokenizer(
aphrodite/endpoints/openai/chat_utils.py (43 additions, 47 deletions)
@@ -1,4 +1,5 @@
import codecs
import tempfile
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final
@@ -7,10 +8,11 @@
from loguru import logger
from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam)
from transformers import PreTrainedTokenizer

from aphrodite.common.config import ModelConfig
from aphrodite.endpoints.openai.protocol import (
ChatCompletionContentPartParam, ChatCompletionMessageParam)
from aphrodite.endpoints.openai.serving_engine import OpenAIServing
from aphrodite.multimodal import MultiModalDataDict
from aphrodite.multimodal.utils import async_get_and_parse_image

@@ -28,64 +30,56 @@ class ChatMessageParseResult:
default_factory=list)


def load_chat_template(self, chat_template: Optional[str]):
tokenizer = self.tokenizer

if chat_template is not None:
try:
if chat_template.startswith('http'):
response = requests.get(chat_template)
if response.status_code == 200:
tokenizer.chat_template = response.text
else:
raise ValueError("Failed to download chat template "
f"from {chat_template}")
else:
with open(chat_template, "r") as f:
tokenizer.chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e

# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer.chat_template = codecs.decode(chat_template,
"unicode_escape")

logger.info("Using supplied chat template")
elif tokenizer.chat_template is not None:
logger.info("Using default chat template")
else:
logger.warning("No chat template provided. Chat API will not work.")
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
if chat_template is None:
return None
try:
if chat_template.startswith(('http')):
response = requests.get(chat_template)
temp = tempfile.NamedTemporaryFile(delete=False)
temp.write(response.content)
temp.close()
chat_template = temp.name

with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e

# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")

logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
return resolved_chat_template


@lru_cache(maxsize=None)
def _image_token_str(engine: OpenAIServing) -> Optional[str]:
def _image_token_str(model_config: ModelConfig,
tokenizer: PreTrainedTokenizer) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = engine.model_config.hf_config.model_type
model_type = model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return engine.tokenizer.decode(
engine.model_config.hf_config.image_token_index)
return tokenizer.decode(model_config.hf_config.image_token_index)

else:
raise TypeError(f"Unknown model type: {model_type}")
raise TypeError(f"Unknown model type: {model_type}")


# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
text_prompt: str) -> str:
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
"""Combine image and text prompts for vision language model"""

# NOTE: For now we assume all model architectures use the same
@@ -94,9 +88,10 @@ def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,


def _parse_chat_message_content_parts(
engine: OpenAIServing,
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
@@ -127,15 +122,14 @@ def _parse_chat_message_content_parts(
text_prompt = "\n".join(texts)

if mm_futures:
image_token_str = _image_token_str(engine)
image_token_str = _image_token_str(model_config, tokenizer)
if image_token_str is not None:
if image_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = _get_full_image_text_prompt(
engine,
image_token_str=image_token_str,
text_prompt=text_prompt,
)
@@ -146,8 +140,9 @@


def parse_chat_message_content(
engine: OpenAIServing,
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
@@ -158,4 +153,5 @@
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])

return _parse_chat_message_content_parts(engine, role, content)
return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
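
After this refactor, load_chat_template no longer mutates a tokenizer in place; it resolves and returns the template string itself (downloading an http(s) URL to a temporary file, reading a local path, or unicode-unescaping an inline Jinja string), and the caller decides when to apply it. A small usage sketch, assuming the function remains importable from chat_utils as shown above; the file path and URL are illustrative:

```python
from aphrodite.endpoints.openai.chat_utils import load_chat_template

# Any of these forms is accepted; passing None falls back to the tokenizer's default.
template = load_chat_template("./chatml.jinja")          # illustrative local path
# template = load_chat_template("https://example.com/chatml.jinja")  # downloaded first
# template = load_chat_template("{% for m in messages %}{{ m.content }}{% endfor %}")
print(template)
```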
aphrodite/endpoints/openai/protocol.py (2 additions, 0 deletions)
@@ -765,6 +765,7 @@ class BatchRequestOutput(OpenAIBaseModel):


class TokenizeRequest(OpenAIBaseModel):
model: Optional[str]
add_generation_prompt: bool = Field(default=True)
add_special_tokens: bool = Field(default=False)
prompt: Optional[str] = Field(default=None)
@@ -778,6 +779,7 @@ class TokenizeResponse(OpenAIBaseModel):


class DetokenizeRequest(OpenAIBaseModel):
model: Optional[str]
tokens: List[int]


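The new optional `model` field on TokenizeRequest and DetokenizeRequest lets a client tokenize under a specific served LoRA adapter rather than the base model. A hedged example of such a call; the host, port, route, and adapter name below are assumptions for illustration, not taken from this diff:

```python
import requests

BASE_URL = "http://localhost:2242"        # assumed server address
resp = requests.post(
    f"{BASE_URL}/v1/tokenize",            # assumed tokenize route
    json={
        "model": "my-lora-adapter",       # name the LoRA module is served under
        "prompt": "Hello, world!",
        "add_special_tokens": False,
    },
)
print(resp.json())                        # expected to include the token ids
```
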
aphrodite/endpoints/openai/serving_chat.py (42 additions, 29 deletions)
@@ -6,6 +6,7 @@

from fastapi import Request
from loguru import logger
from transformers import PreTrainedTokenizer

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
@@ -45,7 +46,8 @@ def __init__(self,
lora_modules=lora_modules)

self.response_role = response_role
load_chat_template(self, chat_template)
# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)

async def create_chat_completion(
self,
@@ -67,11 +69,14 @@ async def create_chat_completion(
return error_check_ret

try:
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []

for msg in request.messages:
chat_parsed_result = parse_chat_message_content(self, msg)
chat_parsed_result = parse_chat_message_content(
msg, self.model_config, tokenizer)

conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures)
@@ -80,13 +85,13 @@
tool.model_dump() for tool in request.tools
]

prompt = self.tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
@@ -108,19 +113,19 @@ async def create_chat_completion(
request_id = f"cmpl-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
request,
tokenizer,
prompt=prompt,
add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params()
_, lora_request = self._maybe_get_adapter(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
@@ -145,12 +150,12 @@ async def create_chat_completion(
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation)
request, result_generator, request_id, conversation, tokenizer)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
conversation)
conversation, tokenizer)
except ValueError as e:
# TODO: Use an aphrodite-specific Validation Error
return self.create_error_response(str(e))
@@ -162,9 +167,12 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
return request.messages[-1]["role"]

async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
@@ -251,6 +259,7 @@ async def chat_completion_stream_generator(
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
@@ -339,9 +348,13 @@ async def chat_completion_stream_generator(
yield "data: [DONE]\n\n"

async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
self,
request: ChatCompletionRequest,
raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = self.served_model_names[0]
@@ -368,6 +381,7 @@ async def chat_completion_full_generator(
logprobs = self._create_chat_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
@@ -423,16 +437,14 @@ async def chat_completion_full_generator(
return response

def _get_top_logprobs(
self, logprobs: Dict[int, Logprob],
top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=self._get_decoded_token(p[1], p[0]),
token=(token := self._get_decoded_token(p[1], p[0],
tokenizer)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
self._get_decoded_token(p[1],
p[0]).encode("utf-8",
errors="replace")))
bytes=list(token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
@@ -441,6 +453,7 @@ def _create_chat_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
@@ -450,12 +463,11 @@ def _create_chat_logprobs(
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self.tokenizer.decode(token_id),
bytes=list(
self.tokenizer.decode(token_id).encode(
"utf-8", errors="replace"))))
token=token,
bytes=list(token.encode("utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
@@ -466,6 +478,7 @@
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs)))
step_top_logprobs, num_output_top_logprobs,
tokenizer)))

return ChatCompletionLogProbs(content=logprobs_content)