
Inference backend support system prompt (#3313)
olliestanley authored Jun 7, 2023
1 parent fe5f2c5 commit 0fcf3e0
Showing 6 changed files with 27 additions and 2 deletions.
1 change: 1 addition & 0 deletions inference/server/oasst_inference_server/routes/chats.py
@@ -140,6 +140,7 @@ async def create_assistant_message(
     work_parameters = inference.WorkParameters(
         model_config=model_config,
         sampling_parameters=request.sampling_parameters,
+        system_prompt=request.system_prompt,
         plugins=request.plugins,
         plugin_max_depth=settings.plugin_max_depth,
     )
1 change: 1 addition & 0 deletions inference/server/oasst_inference_server/schemas/chat.py
@@ -14,6 +14,7 @@ class CreateAssistantMessageRequest(pydantic.BaseModel):
     parent_id: str
     model_config_name: str
     sampling_parameters: inference.SamplingParameters = pydantic.Field(default_factory=inference.SamplingParameters)
+    system_prompt: str | None = None
     plugins: list[inference.PluginEntry] = pydantic.Field(default_factory=list[inference.PluginEntry])
     used_plugin: inference.PluginUsed | None = None
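
With this schema change, a client can set the new optional field when requesting an assistant message. The payload below is a hypothetical sketch based only on the fields shown above; the endpoint details and field values are placeholders, not part of this commit.

# Hypothetical request body for create_assistant_message, based on the
# CreateAssistantMessageRequest schema above. Values are placeholders.
import json

payload = {
    "parent_id": "some-parent-message-id",
    "model_config_name": "some-model-config",
    "system_prompt": "You are a concise, helpful assistant.",  # new optional field
}
print(json.dumps(payload, indent=2))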
1 change: 1 addition & 0 deletions inference/worker/chat_chain_prompts.py
@@ -1,5 +1,6 @@
 V2_ASST_PREFIX = "<|assistant|>"
 V2_PROMPTER_PREFIX = "<|prompter|>"
+V2_SYSTEM_PREFIX = "<|system|>"

 ASSISTANT_PREFIX = "Open Assistant"
 HUMAN_PREFIX = "Human"
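
As a minimal illustration of how the new token frames a conversation (the "</s>" end-of-sequence token here is an assumption; the actual token is model-specific):

# Illustration of V2 prompt framing once a system prompt is present.
V2_SYSTEM_PREFIX = "<|system|>"
V2_PROMPTER_PREFIX = "<|prompter|>"
V2_ASST_PREFIX = "<|assistant|>"

prompt = (
    V2_SYSTEM_PREFIX + "Answer briefly." + "</s>"
    + V2_PROMPTER_PREFIX + "What is a system prompt?" + "</s>"
    + V2_ASST_PREFIX
)
print(prompt)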
19 changes: 17 additions & 2 deletions inference/worker/utils.py
@@ -11,7 +11,7 @@
 import sseclient
 import transformers
 import websocket
-from chat_chain_prompts import V2_PROMPTER_PREFIX
+from chat_chain_prompts import V2_PROMPTER_PREFIX, V2_SYSTEM_PREFIX
 from loguru import logger
 from oasst_shared.schemas import inference
 from settings import settings
@@ -80,12 +80,23 @@ def truncate_prompt(
 ):
     with shared_tokenizer_lock:
         ids = tokenizer.encode(prompt)
+        prompter_prefix_id = tokenizer.convert_tokens_to_ids(V2_PROMPTER_PREFIX)
+
+    system_prompt: str | None = None
+    system_tokens: list[int] | None = None
+    if prompt.startswith(V2_SYSTEM_PREFIX):
+        system_prompt = prompt[: prompt.index(V2_PROMPTER_PREFIX)]
+        system_tokens = ids[: ids.index(prompter_prefix_id)]

     max_input_length = get_max_input_length(worker_config, plugin_used)

     if len(ids) > max_input_length:
         logger.debug(f"Prompt too long, left-truncating to {max_input_length} tokens")
-        ids = ids[-(max_input_length - 1) :]
+
+        num_system_tokens = len(system_tokens) if system_tokens else 0
+        # Maximum token allowed for the conversation, ex system prompt
+        max_conversation_length = max_input_length - num_system_tokens
+        ids = ids[-(max_conversation_length - 1) :]

     with shared_tokenizer_lock:
         prompt = tokenizer.decode(ids)
@@ -94,6 +105,10 @@
         prompt = V2_PROMPTER_PREFIX + prompt
         ids = tokenizer.encode(V2_PROMPTER_PREFIX) + ids

+    if system_tokens:
+        prompt = system_prompt + prompt
+        ids = system_tokens + ids
+
     max_total_tokens = worker_config.model_config.max_total_length
     input_length = len(ids)
     spare = max_total_tokens - input_length - 1
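
A minimal, self-contained sketch of the truncation behaviour added above, using plain token-id lists in place of a real tokenizer; the function and variable names here are illustrative, not part of the codebase.

# Sketch of system-prompt-aware left-truncation, assuming ids is the encoded
# prompt and prompter_prefix_id marks the start of the first prompter turn.
def truncate_preserving_system(
    ids: list[int], prompter_prefix_id: int, has_system_prompt: bool, max_input_length: int
) -> list[int]:
    system_tokens: list[int] = []
    if has_system_prompt:
        # Everything before the first prompter token belongs to the system prompt.
        system_tokens = ids[: ids.index(prompter_prefix_id)]

    if len(ids) > max_input_length:
        # Reserve room for the system prompt, left-truncate only the rest,
        # then re-attach the system tokens that truncation cut off.
        max_conversation_length = max_input_length - len(system_tokens)
        ids = system_tokens + ids[-(max_conversation_length - 1):]
    return ids

# e.g. with a budget of 6 tokens, the two system tokens survive truncation:
print(truncate_preserving_system([9, 9, 1, 2, 3, 4, 5, 6], 1, True, 6))  # [9, 9, 4, 5, 6]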
6 changes: 6 additions & 0 deletions inference/worker/work.py
@@ -15,6 +15,7 @@
     THOUGHT_SEQ,
     V2_ASST_PREFIX,
     V2_PROMPTER_PREFIX,
+    V2_SYSTEM_PREFIX,
 )
 from loguru import logger
 from oasst_shared.schemas import inference
@@ -38,13 +39,18 @@ def _prepare_message(message: inference.MessageRead) -> str:
     # construct prompt
     messages = [_prepare_message(message) for message in work_request.thread.messages]

+    if work_request.parameters.system_prompt:
+        pre_prompt = V2_SYSTEM_PREFIX + work_request.parameters.system_prompt + eos_token
+        messages = [pre_prompt] + messages
+
     prompt = "".join(messages) + V2_ASST_PREFIX

     parameters = interface.GenerateStreamParameters.from_work_parameters(work_request.parameters)
     if settings.use_stop_sequences:
         parameters.stop = [
             V2_PROMPTER_PREFIX,
             V2_ASST_PREFIX,
+            V2_SYSTEM_PREFIX,
         ]
         if eos_token:
             parameters.stop.append(eos_token)
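
The worker's prompt assembly above can be summarised in a standalone sketch; build_prompt is a name invented for illustration, and "</s>" as the eos_token is an assumption.

# Sketch of the worker's prompt assembly with an optional system pre-prompt.
V2_SYSTEM_PREFIX = "<|system|>"
V2_ASST_PREFIX = "<|assistant|>"

def build_prompt(messages: list[str], system_prompt: str | None, eos_token: str) -> str:
    if system_prompt:
        # The system pre-prompt is prepended as its own turn, closed by eos_token.
        messages = [V2_SYSTEM_PREFIX + system_prompt + eos_token] + messages
    return "".join(messages) + V2_ASST_PREFIX

print(build_prompt(["<|prompter|>Hi</s>"], "Be concise.", "</s>"))
# -> <|system|>Be concise.</s><|prompter|>Hi</s><|assistant|>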
1 change: 1 addition & 0 deletions oasst-shared/oasst_shared/schemas/inference.py
@@ -200,6 +200,7 @@ class WorkParameters(pydantic.BaseModel):
     seed: int = pydantic.Field(
         default_factory=make_seed,
     )
+    system_prompt: str | None = None
     plugins: list[PluginEntry] = pydantic.Field(default_factory=list[PluginEntry])
     plugin_max_depth: int = 4
