Add inference server setting to control max depth of plugins (#3227)
Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
olliestanley and andreaskoepf authored May 26, 2023
1 parent e059d86 commit dcda41c
Showing 5 changed files with 21 additions and 6 deletions.
3 changes: 3 additions & 0 deletions ansible/inference/deploy-server.yaml
@@ -161,5 +161,8 @@
       MESSAGE_MAX_LENGTH:
         "{{ lookup('ansible.builtin.env', 'INFERENCE_MESSAGE_MAX_LENGTH') |
         default('', true) }}"
+      PLUGIN_MAX_DEPTH:
+        "{{ lookup('ansible.builtin.env', 'INFERENCE_PLUGIN_MAX_DEPTH') |
+        default(4, true) }}"
     ports:
       - "{{ server_port }}:8080"
1 change: 1 addition & 0 deletions inference/server/oasst_inference_server/routes/chats.py
@@ -141,6 +141,7 @@ async def create_assistant_message(
         model_config=model_config,
         sampling_parameters=request.sampling_parameters,
         plugins=request.plugins,
+        plugin_max_depth=settings.plugin_max_depth,
     )
     assistant_message = await ucr.initiate_assistant_message(
         parent_id=request.parent_id,
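
The depth cap is injected here from server settings rather than read from the client request, so API clients cannot raise it themselves. A simplified, hypothetical sketch of that wiring (toy models, not the real oasst schemas):

import pydantic

class CreateMessageRequest(pydantic.BaseModel):
    # Toy request model: note there is no plugin_max_depth field.
    plugins: list[str] = pydantic.Field(default_factory=list)

class WorkParameters(pydantic.BaseModel):
    plugins: list[str] = pydantic.Field(default_factory=list)
    plugin_max_depth: int = 4

SERVER_PLUGIN_MAX_DEPTH = 4  # stands in for settings.plugin_max_depth

request = CreateMessageRequest(plugins=["calculator"])
work_parameters = WorkParameters(
    plugins=request.plugins,
    plugin_max_depth=SERVER_PLUGIN_MAX_DEPTH,  # server-controlled
)
print(work_parameters.plugin_max_depth)  # 4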
3 changes: 3 additions & 0 deletions inference/server/oasst_inference_server/settings.py
@@ -112,6 +112,9 @@ def trusted_api_keys_list(self) -> list[str]:
 
     inference_cors_origins: str = "*"
 
+    # sent as a work parameter, higher values increase load on workers
+    plugin_max_depth: int = 4
+
     @property
     def inference_cors_origins_list(self) -> list[str]:
         return self.inference_cors_origins.split(",")
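
Assuming the Settings class here is a pydantic v1 BaseSettings, as is used elsewhere in this server, the new field is populated case-insensitively from the PLUGIN_MAX_DEPTH container variable set in the Ansible deploy above. A minimal sketch under that assumption:

import os
import pydantic

class Settings(pydantic.BaseSettings):
    # Filled from the environment variable PLUGIN_MAX_DEPTH (case-insensitive).
    plugin_max_depth: int = 4

os.environ["PLUGIN_MAX_DEPTH"] = "8"
print(Settings().plugin_max_depth)  # 8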
19 changes: 13 additions & 6 deletions inference/worker/chat_chain.py
@@ -25,9 +25,6 @@
 from oasst_shared.schemas import inference
 from settings import settings
 
-# Max depth of retries for tool usage
-MAX_DEPTH = 6
-
 # Exclude tools description from final prompt. Saves ctx space but can hurt output
 # quality especially if truncation kicks in. Dependent on model used
 REMOVE_TOOLS_FROM_FINAL_PROMPT = False
@@ -110,6 +107,7 @@ def handle_plugin_usage(
     parameters: interface.GenerateStreamParameters,
     tools: list[Tool],
     plugin: inference.PluginEntry | None,
+    plugin_max_depth: int,
 ) -> tuple[str, inference.PluginUsed]:
     execution_details = inference.PluginExecutionDetails(
         inner_monologue=[],
@@ -155,7 +153,7 @@ def handle_plugin_usage(
     assisted = False if ASSISTANT_PREFIX in prefix else True
     chain_finished = not assisted
 
-    while not chain_finished and assisted and achieved_depth < MAX_DEPTH:
+    while not chain_finished and assisted and achieved_depth < plugin_max_depth:
         tool_response = use_tool(prefix, response, tools)
 
         # Save previous chain response for use in final prompt
@@ -238,7 +236,7 @@ def handle_plugin_usage(
         plugin_used.execution_details.final_prompt = init_prompt
         plugin_used.execution_details.achieved_depth = achieved_depth
         plugin_used.execution_details.status = "failure"
-        plugin_used.execution_details.error_message = f"Max depth reached: {MAX_DEPTH}"
+        plugin_used.execution_details.error_message = f"Max depth reached: {plugin_max_depth}"
     init_prompt = f"{init_prompt}{THOUGHT_SEQ} I now know the final answer\n{ASSISTANT_PREFIX}: "
     return init_prompt, plugin_used
 
@@ -315,7 +313,16 @@ def handle_conversation(
 
     if plugin_enabled:
         return handle_plugin_usage(
-            original_prompt, prompt_template, language, memory, worker_config, tokenizer, parameters, tools, plugin
+            original_prompt,
+            prompt_template,
+            language,
+            memory,
+            worker_config,
+            tokenizer,
+            parameters,
+            tools,
+            plugin,
+            work_request.parameters.plugin_max_depth,
         )
 
     return handle_standard_usage(original_prompt, prompt_template, language, memory, worker_config, tokenizer)
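
Taken together, these hunks replace the module-level MAX_DEPTH = 6 constant with a per-request plugin_max_depth argument threaded in from the work parameters. An illustrative toy of the depth-limited loop (use_tool, prompts, and the real chain state are elided; steps_needed stands in for the model deciding it has an answer):

def run_chain(steps_needed: int, plugin_max_depth: int) -> tuple[str, int]:
    achieved_depth = 0
    chain_finished = False
    while not chain_finished and achieved_depth < plugin_max_depth:
        achieved_depth += 1
        # In the real loop this is a tool call plus another model generation.
        chain_finished = achieved_depth >= steps_needed
    if chain_finished:
        return "success", achieved_depth
    return f"Max depth reached: {plugin_max_depth}", achieved_depth

print(run_chain(steps_needed=2, plugin_max_depth=4))  # ('success', 2)
print(run_chain(steps_needed=9, plugin_max_depth=4))  # ('Max depth reached: 4', 4)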
1 change: 1 addition & 0 deletions oasst-shared/oasst_shared/schemas/inference.py
@@ -201,6 +201,7 @@ class WorkParameters(pydantic.BaseModel):
         default_factory=make_seed,
     )
     plugins: list[PluginEntry] = pydantic.Field(default_factory=list[PluginEntry])
+    plugin_max_depth: int = 4
 
 
 class ReportType(str, enum.Enum):
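
With the schema-level default of 4, a work request that omits the field still validates, so workers keep a depth cap even when a server does not send one. A minimal sketch with a simplified PluginEntry stand-in (pydantic v1 API):

import pydantic

class PluginEntry(pydantic.BaseModel):  # simplified stand-in for the real schema
    url: str

class WorkParameters(pydantic.BaseModel):
    plugins: list[PluginEntry] = pydantic.Field(default_factory=list)
    plugin_max_depth: int = 4

print(WorkParameters.parse_obj({}).plugin_max_depth)           # 4 (schema default)
print(WorkParameters(plugin_max_depth=6).plugin_max_depth)     # 6 (server-supplied)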
