diff --git a/inference/server/oasst_inference_server/routes/chats.py b/inference/server/oasst_inference_server/routes/chats.py
index 37c47b455f..6b098bfc2f 100644
--- a/inference/server/oasst_inference_server/routes/chats.py
+++ b/inference/server/oasst_inference_server/routes/chats.py
@@ -140,6 +140,7 @@ async def create_assistant_message(
         work_parameters = inference.WorkParameters(
             model_config=model_config,
             sampling_parameters=request.sampling_parameters,
+            system_prompt=request.system_prompt,
             plugins=request.plugins,
             plugin_max_depth=settings.plugin_max_depth,
         )
diff --git a/inference/server/oasst_inference_server/schemas/chat.py b/inference/server/oasst_inference_server/schemas/chat.py
index c53965094e..74ec4a09ac 100644
--- a/inference/server/oasst_inference_server/schemas/chat.py
+++ b/inference/server/oasst_inference_server/schemas/chat.py
@@ -14,6 +14,7 @@ class CreateAssistantMessageRequest(pydantic.BaseModel):
     parent_id: str
     model_config_name: str
     sampling_parameters: inference.SamplingParameters = pydantic.Field(default_factory=inference.SamplingParameters)
+    system_prompt: str | None = None
     plugins: list[inference.PluginEntry] = pydantic.Field(default_factory=list[inference.PluginEntry])
     used_plugin: inference.PluginUsed | None = None
diff --git a/inference/worker/chat_chain_prompts.py b/inference/worker/chat_chain_prompts.py
index 310ec59609..299dfa8f57 100644
--- a/inference/worker/chat_chain_prompts.py
+++ b/inference/worker/chat_chain_prompts.py
@@ -1,5 +1,6 @@
 V2_ASST_PREFIX = "<|assistant|>"
 V2_PROMPTER_PREFIX = "<|prompter|>"
+V2_SYSTEM_PREFIX = "<|system|>"
 ASSISTANT_PREFIX = "Open Assistant"
 HUMAN_PREFIX = "Human"
diff --git a/inference/worker/utils.py b/inference/worker/utils.py
index c3528e4f1a..fa1ce6dfad 100644
--- a/inference/worker/utils.py
+++ b/inference/worker/utils.py
@@ -11,7 +11,7 @@
 import sseclient
 import transformers
 import websocket
-from chat_chain_prompts import V2_PROMPTER_PREFIX
+from chat_chain_prompts import V2_PROMPTER_PREFIX, V2_SYSTEM_PREFIX
 from loguru import logger
 from oasst_shared.schemas import inference
 from settings import settings
@@ -80,12 +80,23 @@ def truncate_prompt(
 ):
     with shared_tokenizer_lock:
         ids = tokenizer.encode(prompt)
+        prompter_prefix_id = tokenizer.convert_tokens_to_ids(V2_PROMPTER_PREFIX)
+
+        system_prompt: str | None = None
+        system_tokens: list[int] | None = None
+        if prompt.startswith(V2_SYSTEM_PREFIX):
+            system_prompt = prompt[: prompt.index(V2_PROMPTER_PREFIX)]
+            system_tokens = ids[: ids.index(prompter_prefix_id)]

     max_input_length = get_max_input_length(worker_config, plugin_used)

     if len(ids) > max_input_length:
         logger.debug(f"Prompt too long, left-truncating to {max_input_length} tokens")
-        ids = ids[-(max_input_length - 1) :]
+
+        num_system_tokens = len(system_tokens) if system_tokens else 0
+        # Maximum token allowed for the conversation, ex system prompt
+        max_conversation_length = max_input_length - num_system_tokens
+        ids = ids[-(max_conversation_length - 1) :]

     with shared_tokenizer_lock:
         prompt = tokenizer.decode(ids)
@@ -94,6 +105,10 @@ def truncate_prompt(
             prompt = V2_PROMPTER_PREFIX + prompt
             ids = tokenizer.encode(V2_PROMPTER_PREFIX) + ids

+    if system_tokens:
+        prompt = system_prompt + prompt
+        ids = system_tokens + ids
+
     max_total_tokens = worker_config.model_config.max_total_length
     input_length = len(ids)
     spare = max_total_tokens - input_length - 1
diff --git a/inference/worker/work.py b/inference/worker/work.py
index bb4a5e9f10..45fd197f5c 100644
--- a/inference/worker/work.py
+++ b/inference/worker/work.py
@@ -15,6 +15,7 @@
     THOUGHT_SEQ,
     V2_ASST_PREFIX,
     V2_PROMPTER_PREFIX,
+    V2_SYSTEM_PREFIX,
 )
 from loguru import logger
 from oasst_shared.schemas import inference
@@ -38,6 +39,10 @@ def _prepare_message(message: inference.MessageRead) -> str:
     # construct prompt
     messages = [_prepare_message(message) for message in work_request.thread.messages]

+    if work_request.parameters.system_prompt:
+        pre_prompt = V2_SYSTEM_PREFIX + work_request.parameters.system_prompt + eos_token
+        messages = [pre_prompt] + messages
+
     prompt = "".join(messages) + V2_ASST_PREFIX

     parameters = interface.GenerateStreamParameters.from_work_parameters(work_request.parameters)
@@ -45,6 +50,7 @@ def _prepare_message(message: inference.MessageRead) -> str:
     parameters.stop = [
         V2_PROMPTER_PREFIX,
         V2_ASST_PREFIX,
+        V2_SYSTEM_PREFIX,
     ]
     if eos_token:
         parameters.stop.append(eos_token)
diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py
index fe6e7467b7..9ee7d129e9 100644
--- a/oasst-shared/oasst_shared/schemas/inference.py
+++ b/oasst-shared/oasst_shared/schemas/inference.py
@@ -200,6 +200,7 @@ class WorkParameters(pydantic.BaseModel):
     seed: int = pydantic.Field(
         default_factory=make_seed,
     )
+    system_prompt: str | None = None
     plugins: list[PluginEntry] = pydantic.Field(default_factory=list[PluginEntry])
     plugin_max_depth: int = 4
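
Note (not part of the patch): a minimal sketch, under assumptions, of the prompt layout the worker builds once request.system_prompt is set. The real eos_token comes from the model's tokenizer/config at runtime; "</s>" and the example messages below are illustrative placeholders, not values from this diff.

# Sketch of the prompt assembly added in inference/worker/work.py (hypothetical values).
V2_SYSTEM_PREFIX = "<|system|>"
V2_PROMPTER_PREFIX = "<|prompter|>"
V2_ASST_PREFIX = "<|assistant|>"
eos_token = "</s>"  # assumed placeholder; the worker reads the real token from the model config

system_prompt = "You are a helpful assistant."  # corresponds to request.system_prompt
thread = [V2_PROMPTER_PREFIX + "What is the capital of France?" + eos_token]

# Mirrors the new branch in work.py: prepend the system pre-prompt before the message thread.
messages = ([V2_SYSTEM_PREFIX + system_prompt + eos_token] if system_prompt else []) + thread
prompt = "".join(messages) + V2_ASST_PREFIX

print(prompt)
# <|system|>You are a helpful assistant.</s><|prompter|>What is the capital of France?</s><|assistant|>

With this layout, the updated truncate_prompt() in inference/worker/utils.py recognises the leading <|system|> segment, sets its tokens aside, and left-truncates only the conversation portion, so a long chat can no longer push the system prompt out of the model's input window.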