Skip to content

Commit

Permalink
feat: Add system_instruction to LangchainAgent template.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 656647049
  • Loading branch information
yeesian authored and copybara-github committed Jul 27, 2024
1 parent a02d82f commit c71c3dd
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
_TEST_LOCATION = "us-central1"
_TEST_PROJECT = "test-project"
_TEST_MODEL = "gemini-1.0-pro"
_TEST_SYSTEM_INSTRUCTION = "You are a helpful bot."


def place_tool_query(
Expand Down Expand Up @@ -173,6 +174,7 @@ def test_initialization_with_tools(self, mock_chatvertexai):
]
agent = reasoning_engines.LangchainAgent(
model=_TEST_MODEL,
system_instruction=_TEST_SYSTEM_INSTRUCTION,
tools=tools,
)
for tool, agent_tool in zip(tools, agent._tools):
Expand Down Expand Up @@ -255,11 +257,6 @@ def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock):
assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text


class TestConvertToolsOrRaise:
    # NOTE(review): placeholder test class — the single test body is a bare
    # `pass`, so nothing is actually exercised yet. Presumably kept so the
    # test module mirrors the helper it is meant to cover; confirm intent
    # before removing.
    def test_convert_tools_or_raise(self, vertexai_init_mock):
        """Placeholder test: currently asserts nothing (body is `pass`)."""
        pass


def _return_input_no_typing(input_):
"""Returns input back to user."""
return input_
Expand All @@ -272,3 +269,20 @@ def test_raise_untyped_input_args(self, vertexai_init_mock):
model=_TEST_MODEL,
tools=[_return_input_no_typing],
)


class TestSystemInstructionAndPromptRaisesErrors:
    """Checks that `system_instruction` and `prompt` are mutually exclusive."""

    def test_raise_both_system_instruction_and_prompt_error(self, vertexai_init_mock):
        """Supplying both `prompt` and `system_instruction` must raise ValueError."""
        # Build the prompt up front so the `pytest.raises` block contains
        # only the call under test.
        chat_prompt = prompts.ChatPromptTemplate.from_messages(
            [
                ("user", "{input}"),
            ]
        )
        expected_match = (
            r"Only one of `prompt` or `system_instruction` should be specified."
        )
        with pytest.raises(ValueError, match=expected_match):
            reasoning_engines.LangchainAgent(
                model=_TEST_MODEL,
                system_instruction=_TEST_SYSTEM_INSTRUCTION,
                prompt=chat_prompt,
            )
34 changes: 30 additions & 4 deletions vertexai/preview/reasoning_engines/templates/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _default_model_builder(
def _default_runnable_builder(
model: "BaseLanguageModel",
*,
system_instruction: Optional[str] = None,
tools: Optional[Sequence["_ToolLike"]] = None,
prompt: Optional["RunnableSerializable"] = None,
output_parser: Optional["RunnableSerializable"] = None,
Expand All @@ -131,7 +132,10 @@ def _default_runnable_builder(
# user would reflect that is by setting chat_history (which defaults to
# None).
has_history: bool = chat_history is not None
prompt = prompt or _default_prompt(has_history)
prompt = prompt or _default_prompt(
has_history=has_history,
system_instruction=system_instruction,
)
output_parser = output_parser or _default_output_parser()
model_tool_kwargs = model_tool_kwargs or {}
agent_executor_kwargs = agent_executor_kwargs or {}
Expand Down Expand Up @@ -162,7 +166,10 @@ def _default_runnable_builder(
return agent_executor


def _default_prompt(has_history: bool) -> "RunnableSerializable":
def _default_prompt(
has_history: bool,
system_instruction: Optional[str] = None,
) -> "RunnableSerializable":
from langchain_core import prompts

try:
Expand All @@ -173,6 +180,10 @@ def _default_prompt(has_history: bool) -> "RunnableSerializable":
format_to_openai_tool_messages as format_to_tool_messages,
)

system_instructions = []
if system_instruction:
system_instructions = [("system", system_instruction)]

if has_history:
return {
"history": lambda x: x["history"],
Expand All @@ -181,7 +192,8 @@ def _default_prompt(has_history: bool) -> "RunnableSerializable":
lambda x: format_to_tool_messages(x["intermediate_steps"])
),
} | prompts.ChatPromptTemplate.from_messages(
[
system_instructions
+ [
prompts.MessagesPlaceholder(variable_name="history"),
("user", "{input}"),
prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
Expand All @@ -194,7 +206,8 @@ def _default_prompt(has_history: bool) -> "RunnableSerializable":
lambda x: format_to_tool_messages(x["intermediate_steps"])
),
} | prompts.ChatPromptTemplate.from_messages(
[
system_instructions
+ [
("user", "{input}"),
prompts.MessagesPlaceholder(variable_name="agent_scratchpad"),
]
Expand Down Expand Up @@ -265,6 +278,7 @@ def __init__(
self,
model: str,
*,
system_instruction: Optional[str] = None,
prompt: Optional["RunnableSerializable"] = None,
tools: Optional[Sequence["_ToolLike"]] = None,
output_parser: Optional["RunnableSerializable"] = None,
Expand Down Expand Up @@ -319,6 +333,9 @@ def __init__(
Args:
model (str):
Optional. The name of the model (e.g. "gemini-1.0-pro").
system_instruction (str):
Optional. The system instruction to use for the agent. This
argument should not be specified if `prompt` is specified.
prompt (langchain_core.runnables.RunnableSerializable):
Optional. The prompt template for the model. Defaults to a
ChatPromptTemplate.
Expand Down Expand Up @@ -394,6 +411,7 @@ def __init__(
False.
Raises:
ValueError: If both `prompt` and `system_instruction` are specified.
TypeError: If there is an invalid tool (e.g. function with an input
that did not specify its type).
"""
Expand All @@ -407,7 +425,14 @@ def __init__(
# they are deployed.
_validate_tools(tools)
self._tools = tools
if prompt and system_instruction:
raise ValueError(
"Only one of `prompt` or `system_instruction` should be specified. "
"Consider incorporating the system instruction into the prompt "
"rather than passing it separately as an argument."
)
self._model_name = model
self._system_instruction = system_instruction
self._prompt = prompt
self._output_parser = output_parser
self._chat_history = chat_history
Expand Down Expand Up @@ -528,6 +553,7 @@ def set_up(self):
prompt=self._prompt,
model=self._model,
tools=self._tools,
system_instruction=self._system_instruction,
output_parser=self._output_parser,
chat_history=self._chat_history,
model_tool_kwargs=self._model_tool_kwargs,
Expand Down

0 comments on commit c71c3dd

Please sign in to comment.