chore: use the LoRA tokenizer in OpenAI API (PygmalionAI#599)
AlpinDale authored and 50h100a committed Sep 1, 2024
1 parent ff5f772 commit be4438d
Showing 12 changed files with 179 additions and 138 deletions.
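In short, the OpenAI-compatible endpoints stop rendering every request with the single server-wide tokenizer and instead look up the tokenizer that belongs to the LoRA adapter named by the request. A minimal sketch of the pattern, built from the serving_chat.py hunks below; render_prompt is an illustrative helper name, not part of the commit:

# Before this commit: every request rendered its prompt with the base tokenizer.
#     prompt = self.tokenizer.apply_chat_template(conversation=conversation,
#                                                 tokenize=False)

# After: the tokenizer is resolved per request from the LoRA adapter, if any.
async def render_prompt(self, request, conversation):
    _, lora_request = self._maybe_get_adapter(request)         # adapter named by request.model
    tokenizer = await self.engine.get_tokenizer(lora_request)  # adapter (or base) tokenizer
    return tokenizer.apply_chat_template(conversation=conversation,
                                         tokenize=False)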
3 changes: 2 additions & 1 deletion aphrodite/endpoints/openai/api_server.py
@@ -528,7 +528,8 @@ def run_server(args, llm_engine=None):
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names)
openai_serving_tokenization = OpenAIServingTokenization(
engine, model_config, served_model_names, args.chat_template)
engine, model_config, served_model_names, args.lora_modules,
args.chat_template)
app.root_path = args.root_path

tokenizer = get_tokenizer(
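For context, run_server now hands the LoRA module list to the tokenization handler as well, so the tokenize/detokenize endpoints can resolve adapter names the same way the chat route does. An annotated restatement of the updated call; the comments are editorial, not part of the commit:

openai_serving_tokenization = OpenAIServingTokenization(
    engine,               # async engine that owns the (LoRA-aware) tokenizers
    model_config,         # config of the loaded base model
    served_model_names,   # model names the server answers to
    args.lora_modules,    # registered LoRA adapters, used to map names to adapters
    args.chat_template)   # optional chat template path/URL/literal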
90 changes: 43 additions & 47 deletions aphrodite/endpoints/openai/chat_utils.py
@@ -1,4 +1,5 @@
import codecs
import tempfile
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final
@@ -7,10 +8,11 @@
from loguru import logger
from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam)
from transformers import PreTrainedTokenizer

from aphrodite.common.config import ModelConfig
from aphrodite.endpoints.openai.protocol import (
ChatCompletionContentPartParam, ChatCompletionMessageParam)
from aphrodite.endpoints.openai.serving_engine import OpenAIServing
from aphrodite.multimodal import MultiModalDataDict
from aphrodite.multimodal.utils import async_get_and_parse_image

@@ -28,64 +30,56 @@ class ChatMessageParseResult:
default_factory=list)


def load_chat_template(self, chat_template: Optional[str]):
tokenizer = self.tokenizer

if chat_template is not None:
try:
if chat_template.startswith('http'):
response = requests.get(chat_template)
if response.status_code == 200:
tokenizer.chat_template = response.text
else:
raise ValueError("Failed to download chat template "
f"from {chat_template}")
else:
with open(chat_template, "r") as f:
tokenizer.chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e

# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer.chat_template = codecs.decode(chat_template,
"unicode_escape")

logger.info("Using supplied chat template")
elif tokenizer.chat_template is not None:
logger.info("Using default chat template")
else:
logger.warning("No chat template provided. Chat API will not work.")
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
if chat_template is None:
return None
try:
if chat_template.startswith(('http')):
response = requests.get(chat_template)
temp = tempfile.NamedTemporaryFile(delete=False)
temp.write(response.content)
temp.close()
chat_template = temp.name

with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e

# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")

logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
return resolved_chat_template


@lru_cache(maxsize=None)
def _image_token_str(engine: OpenAIServing) -> Optional[str]:
def _image_token_str(model_config: ModelConfig,
tokenizer: PreTrainedTokenizer) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = engine.model_config.hf_config.model_type
model_type = model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return engine.tokenizer.decode(
engine.model_config.hf_config.image_token_index)
return tokenizer.decode(model_config.hf_config.image_token_index)

else:
raise TypeError(f"Unknown model type: {model_type}")
raise TypeError(f"Unknown model type: {model_type}")


# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
text_prompt: str) -> str:
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
"""Combine image and text prompts for vision language model"""

# NOTE: For now we assume all model architectures use the same
@@ -94,9 +88,10 @@ def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,


def _parse_chat_message_content_parts(
engine: OpenAIServing,
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
@@ -127,15 +122,14 @@ def _parse_chat_message_content_parts(
text_prompt = "\n".join(texts)

if mm_futures:
image_token_str = _image_token_str(engine)
image_token_str = _image_token_str(model_config, tokenizer)
if image_token_str is not None:
if image_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = _get_full_image_text_prompt(
engine,
image_token_str=image_token_str,
text_prompt=text_prompt,
)
Expand All @@ -146,8 +140,9 @@ def _parse_chat_message_content_parts(


def parse_chat_message_content(
engine: OpenAIServing,
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
@@ -158,4 +153,5 @@
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])

return _parse_chat_message_content_parts(engine, role, content)
return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
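Note that load_chat_template is now a plain function that returns the resolved template text (or None) rather than mutating a shared tokenizer, so each request can pair the template with whatever tokenizer its adapter provides. A small usage sketch under that reading; the file path and template string are illustrative, not from the commit:

from aphrodite.endpoints.openai.chat_utils import load_chat_template

# A filesystem path or an http(s) URL (URLs are downloaded to a temp file first).
tmpl = load_chat_template("examples/chatml.jinja")

# A literal Jinja template: opening it as a file fails, it contains Jinja
# characters, so escape sequences like \n are decoded and the string is returned.
tmpl = load_chat_template(
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\\n{% endfor %}")

# None means "keep the tokenizer's own default template".
assert load_chat_template(None) is None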
2 changes: 2 additions & 0 deletions aphrodite/endpoints/openai/protocol.py
@@ -765,6 +765,7 @@ class BatchRequestOutput(OpenAIBaseModel):


class TokenizeRequest(OpenAIBaseModel):
model: Optional[str]
add_generation_prompt: bool = Field(default=True)
add_special_tokens: bool = Field(default=False)
prompt: Optional[str] = Field(default=None)
@@ -778,6 +779,7 @@ class TokenizeResponse(OpenAIBaseModel):


class DetokenizeRequest(OpenAIBaseModel):
model: Optional[str]
tokens: List[int]


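The new optional model field is what lets a tokenization request name a LoRA adapter so that adapter's tokenizer (including any added tokens) is used. A hedged client sketch; the port, route, and adapter name are assumptions for illustration, not taken from the commit:

import requests

resp = requests.post(
    "http://localhost:2242/tokenize",  # hypothetical server address and route
    json={
        "model": "my-lora-adapter",    # LoRA adapter name; omit to use the base model
        "prompt": "Hello, world!",
        "add_special_tokens": False,   # matches the TokenizeRequest default above
    },
)
print(resp.json())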
71 changes: 42 additions & 29 deletions aphrodite/endpoints/openai/serving_chat.py
@@ -6,6 +6,7 @@

from fastapi import Request
from loguru import logger
from transformers import PreTrainedTokenizer

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
@@ -45,7 +46,8 @@ def __init__(self,
lora_modules=lora_modules)

self.response_role = response_role
load_chat_template(self, chat_template)
# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)

async def create_chat_completion(
self,
@@ -67,11 +69,14 @@ async def create_chat_completion(
return error_check_ret

try:
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []

for msg in request.messages:
chat_parsed_result = parse_chat_message_content(self, msg)
chat_parsed_result = parse_chat_message_content(
msg, self.model_config, tokenizer)

conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures)
@@ -80,13 +85,13 @@
tool.model_dump() for tool in request.tools
]

prompt = self.tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
@@ -108,19 +113,19 @@ async def create_chat_completion(
request_id = f"cmpl-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
request,
tokenizer,
prompt=prompt,
add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params()
_, lora_request = self._maybe_get_adapter(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
@@ -145,12 +150,12 @@ async def create_chat_completion(
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation)
request, result_generator, request_id, conversation, tokenizer)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
conversation)
conversation, tokenizer)
except ValueError as e:
# TODO: Use an aphrodite-specific Validation Error
return self.create_error_response(str(e))
@@ -162,9 +167,12 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
return request.messages[-1]["role"]

async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
@@ -251,6 +259,7 @@ async def chat_completion_stream_generator(
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
@@ -339,9 +348,13 @@ async def chat_completion_stream_generator(
yield "data: [DONE]\n\n"

async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
self,
request: ChatCompletionRequest,
raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = self.served_model_names[0]
@@ -368,6 +381,7 @@ async def chat_completion_full_generator(
logprobs = self._create_chat_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
@@ -423,16 +437,14 @@ async def chat_completion_full_generator(
return response

def _get_top_logprobs(
self, logprobs: Dict[int, Logprob],
top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=self._get_decoded_token(p[1], p[0]),
token=(token := self._get_decoded_token(p[1], p[0],
tokenizer)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
self._get_decoded_token(p[1],
p[0]).encode("utf-8",
errors="replace")))
bytes=list(token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
@@ -441,6 +453,7 @@ def _create_chat_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
@@ -450,12 +463,11 @@
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self.tokenizer.decode(token_id),
bytes=list(
self.tokenizer.decode(token_id).encode(
"utf-8", errors="replace"))))
token=token,
bytes=list(token.encode("utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
Expand All @@ -466,6 +478,7 @@ def _create_chat_logprobs(
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs)))
step_top_logprobs, num_output_top_logprobs,
tokenizer)))

return ChatCompletionLogProbs(content=logprobs_content)
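Taken together, the serving_chat.py changes thread one per-request tokenizer through template rendering, prompt validation, and logprob decoding. A condensed sketch of the new request path assembled from the hunks above (bodies trimmed, error handling omitted; not a drop-in replacement for the real method):

from aphrodite.endpoints.openai.chat_utils import parse_chat_message_content

async def create_chat_completion(self, request, raw_request=None):
    # 1. Resolve the tokenizer for the LoRA adapter named by request.model
    #    (falls back to the base tokenizer when no adapter matches).
    _, lora_request = self._maybe_get_adapter(request)
    tokenizer = await self.engine.get_tokenizer(lora_request)

    # 2. Parse the messages and render the chat template with that tokenizer.
    conversation = []
    for msg in request.messages:
        parsed = parse_chat_message_content(msg, self.model_config, tokenizer)
        conversation.extend(parsed.messages)
    prompt = tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=False,
        add_generation_prompt=request.add_generation_prompt,
        chat_template=request.chat_template or self.chat_template)

    # 3. Tokenize/validate with the same tokenizer, then hand it to the
    #    stream/full generators, which also use it when building logprobs.
    prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
        request, tokenizer, prompt=prompt,
        add_special_tokens=request.add_special_tokens)
    ...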