Merge pull request #293 from zmackie/zmackie/add-stop-config
Adds configurable stop tokens
drazvan committed Feb 15, 2024
2 parents db42b3e + f818197 commit 32f902c
Showing 10 changed files with 76 additions and 7 deletions.
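For orientation: the change lets each task prompt in the guardrails configuration declare an optional "stop" list, which the actions below then forward to the LLM call. A minimal configuration sketch, modeled on the test added at the bottom of this diff (the task name, tokens, and import path are illustrative assumptions, not part of the change itself):

import textwrap

from nemoguardrails import RailsConfig

config = RailsConfig.from_content(
    yaml_content=textwrap.dedent(
        """
        models:
          - type: main
            engine: openai
            model: gpt-3.5-turbo-instruct
        prompts:
          - task: self_check_facts
            stop:
              - "<<end>>"
            content: |-
              Is the "response" supported by the "evidence"? Answer yes or no.
              evidence: {{ evidence }}
              response: {{ response }}
        """
    )
)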
1 change: 1 addition & 0 deletions nemoguardrails/actions/llm/utils.py
@@ -50,6 +50,7 @@ async def llm_call(
all_callbacks = logging_callbacks

if isinstance(prompt, str):
# The `stop` tokens are consumed here, by the underlying generation call
result = await llm.agenerate_prompt(
[StringPromptValue(text=prompt)], callbacks=all_callbacks, stop=stop
)
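For reference, callers pass the stop tokens straight through to this helper; a minimal usage sketch (the surrounding action and the tokens are illustrative, not part of the change):

from nemoguardrails.actions.llm.utils import llm_call

async def generate_with_stop(llm, prompt: str) -> str:
    # Generation halts as soon as the model emits one of these sequences
    # (illustrative tokens; anything the backing model supports will do).
    return await llm_call(llm, prompt, stop=["\nUser:", "<<end>>"])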
3 changes: 2 additions & 1 deletion nemoguardrails/eval/evaluate_factcheck.py
@@ -143,7 +143,8 @@ def check_facts(self, split="positive"):
fact_check_prompt = self.llm_task_manager.render_task_prompt(
Task.SELF_CHECK_FACTS, {"evidence": evidence, "response": answer}
)
fact_check = self.llm(fact_check_prompt)
stop = self.llm_task_manager.get_stop_tokens(Task.SELF_CHECK_FACTS)
fact_check = self.llm(fact_check_prompt, stop=stop)
end_time = time.time()
time.sleep(0.5) # avoid rate-limits
fact_check = fact_check.lower().strip()
3 changes: 2 additions & 1 deletion nemoguardrails/library/hallucination/actions.py
@@ -114,9 +114,10 @@ async def check_hallucination(

# Initialize the LLMCallInfo object
llm_call_info_var.set(LLMCallInfo(task=Task.CHECK_HALLUCINATION.value))
stop = llm_task_manager.get_stop_tokens(task=Task.CHECK_HALLUCINATION)

with llm_params(llm, temperature=0.0):
agreement = await llm_call(llm, prompt)
agreement = await llm_call(llm, prompt, stop=stop)

agreement = agreement.lower().strip()
log.info(f"Agreement result for looking for hallucination is {agreement}.")
6 changes: 4 additions & 2 deletions nemoguardrails/library/llama_guard/actions.py
@@ -70,12 +70,13 @@ async def llama_guard_check_input(
"user_input": user_input,
},
)
stop = llm_task_manager.get_stop_tokens(task=Task.LLAMA_GUARD_CHECK_INPUT)

# Initialize the LLMCallInfo object
llm_call_info_var.set(LLMCallInfo(task=Task.SELF_CHECK_INPUT.value))

with llm_params(llama_guard_llm, temperature=0.0):
result = await llm_call(llama_guard_llm, check_input_prompt)
result = await llm_call(llama_guard_llm, check_input_prompt, stop=stop)

allowed, policy_violations = parse_llama_guard_response(result)
return {"allowed": allowed, "policy_violations": policy_violations}
@@ -101,12 +102,13 @@ async def llama_guard_check_output(
"bot_response": bot_response,
},
)
stop = llm_task_manager.get_stop_tokens(task=Task.LLAMA_GUARD_CHECK_OUTPUT)

# Initialize the LLMCallInfo object
llm_call_info_var.set(LLMCallInfo(task=Task.SELF_CHECK_OUTPUT.value))

with llm_params(llama_guard_llm, temperature=0.0):
result = await llm_call(llama_guard_llm, check_output_prompt)
result = await llm_call(llama_guard_llm, check_output_prompt, stop=stop)

allowed, policy_violations = parse_llama_guard_response(result)
return {"allowed": allowed, "policy_violations": policy_violations}
3 changes: 2 additions & 1 deletion nemoguardrails/library/self_check/facts/actions.py
@@ -47,12 +47,13 @@ async def self_check_facts(
"response": response,
},
)
stop = llm_task_manager.get_stop_tokens(task=Task.SELF_CHECK_FACTS)

# Initialize the LLMCallInfo object
llm_call_info_var.set(LLMCallInfo(task=Task.SELF_CHECK_FACTS.value))

with llm_params(llm, temperature=0.0):
entails = await llm_call(llm, prompt)
entails = await llm_call(llm, prompt, stop=stop)

entails = entails.lower().strip()

3 changes: 2 additions & 1 deletion nemoguardrails/library/self_check/input_check/actions.py
@@ -54,12 +54,13 @@ async def self_check_input(
"user_input": user_input,
},
)
stop = llm_task_manager.get_stop_tokens(task=Task.SELF_CHECK_INPUT)

# Initialize the LLMCallInfo object
llm_call_info_var.set(LLMCallInfo(task=Task.SELF_CHECK_INPUT.value))

with llm_params(llm, temperature=0.0):
check = await llm_call(llm, prompt)
check = await llm_call(llm, prompt, stop=stop)

check = check.lower().strip()
log.info(f"Input self-checking result is: `{check}`.")
3 changes: 2 additions & 1 deletion nemoguardrails/library/self_check/output_check/actions.py
@@ -57,12 +57,13 @@ async def self_check_output(
"bot_response": bot_response,
},
)
stop = llm_task_manager.get_stop_tokens(task=Task.SELF_CHECK_OUTPUT)

# Initialize the LLMCallInfo object
llm_call_info_var.set(LLMCallInfo(task=Task.SELF_CHECK_OUTPUT.value))

with llm_params(llm, temperature=0.0):
response = await llm_call(llm, prompt)
response = await llm_call(llm, prompt, stop=stop)

response = response.lower().strip()
log.info(f"Output self-checking result is: `{response}`.")
5 changes: 5 additions & 0 deletions nemoguardrails/llm/taskmanager.py
@@ -255,6 +255,11 @@ def parse_task_output(self, task: Task, output: str):
else:
return output

def get_stop_tokens(self, task: Union[str, Task]) -> List[str]:
"""Return the stop sequence for the given task."""
prompt = get_prompt(self.config, task)
return prompt.stop

def register_filter(self, filter_fn: callable, name: Optional[str] = None):
"""Register a custom filter for the rails configuration."""
name = name or filter_fn.__name__
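A quick usage sketch for the new helper, assuming a RailsConfig whose prompt declares stop tokens (as in the test at the bottom of this diff); the import paths are my assumption:

from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.llm.types import Task

llm_task_manager = LLMTaskManager(config)  # config: a RailsConfig with "stop" set on the prompt
stop = llm_task_manager.get_stop_tokens(Task.GENERATE_USER_INTENT)
# e.g. ["<<end>>", "<<stop>>"]; None when the prompt does not declare any stop tokens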
4 changes: 4 additions & 0 deletions nemoguardrails/rails/llm/config.py
@@ -138,6 +138,10 @@ class TaskPrompt(BaseModel):
default=_default_config["prompting_mode"],
description="Corresponds to the `prompting_mode` for which this prompt is fetched. Default is 'standard'.",
)
stop: Optional[List[str]] = Field(
default=None,
description="If specified, will be configure stop tokens for models that support this.",
)

@root_validator(pre=True, allow_reuse=True)
def check_fields(cls, values):
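The same field can also be set directly on the pydantic model; a minimal sketch (the task name, prompt text, and tokens are illustrative):

from nemoguardrails.rails.llm.config import TaskPrompt

prompt = TaskPrompt(
    task="self_check_input",
    content="Is the user message below safe? Answer yes or no.\n{{ user_input }}",
    stop=["\nUser:"],  # optional; when omitted, prompt.stop stays None
)
assert prompt.stop == ["\nUser:"]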
52 changes: 52 additions & 0 deletions tests/test_llm_task_manager.py
@@ -222,3 +222,55 @@ def test_prompt_length_exceeded_compressed_history():
events=events,
)
assert len(generate_user_intent_prompt) <= max_task_prompt_length

# Test to check the stop configuration parameter


def test_stop_configuration_parameter():
"""Test the prompts for the OpenAI GPT-3 5 Turbo model."""
config = RailsConfig.from_content(
yaml_content=textwrap.dedent(
"""
models:
- type: main
engine: openai
model: gpt-3.5-turbo-instruct
prompts:
- task: generate_user_intent
stop:
- <<end>>
- <<stop>>
max_length: 3000
content: |-
{{ general_instructions }}
# This is how a conversation between a user and the bot can go:
{{ sample_conversation }}
# This is how the user talks:
{{ examples }}
# This is the current conversation between the user and the bot:
{{ sample_conversation | first_turns(2) }}
{{ history | colang }}
"""
)
)

task_prompt = get_prompt(config, Task.GENERATE_USER_INTENT)

# Assuming the stop parameter is a list of strings
expected_stop_tokens = ["<<end>>", "<<stop>>"]
llm_task_manager = LLMTaskManager(config)

# Render the task prompt with the stop configuration
rendered_prompt = llm_task_manager.render_task_prompt(
task=Task.GENERATE_USER_INTENT,
context={"examples": 'user "Hello there!"\n express greeting'},
events=[],
)

# Check that the stop tokens are correctly set on the task prompt
for stop_token in expected_stop_tokens:
assert stop_token in task_prompt.stop
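
To exercise just this test, something like "pytest tests/test_llm_task_manager.py -k test_stop_configuration_parameter" should work.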
