Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tool_choice for OpenAI and Anthropic #142

Merged
merged 6 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion lib/chat_models/chat_anthropic.ex
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ defmodule LangChain.ChatModels.ChatAnthropic do

# A list of maps for callback handlers
field :callbacks, {:array, :map}, default: []

# Tool choice option
field :tool_choice, :map
end

@type t :: %ChatAnthropic{}
Expand All @@ -128,7 +131,8 @@ defmodule LangChain.ChatModels.ChatAnthropic do
:top_p,
:top_k,
:stream,
:callbacks
:callbacks,
:tool_choice
]
@required_fields [:endpoint, :model]

Expand Down Expand Up @@ -201,11 +205,21 @@ defmodule LangChain.ChatModels.ChatAnthropic do
# Anthropic sets the `system` message on the request body, not as part of the messages list.
|> Utils.conditionally_add_to_map(:system, system_text)
|> Utils.conditionally_add_to_map(:tools, get_tools_for_api(tools))
|> Utils.conditionally_add_to_map(:tool_choice, get_tool_choice(anthropic))
|> Utils.conditionally_add_to_map(:max_tokens, anthropic.max_tokens)
|> Utils.conditionally_add_to_map(:top_p, anthropic.top_p)
|> Utils.conditionally_add_to_map(:top_k, anthropic.top_k)
end

# Normalizes the configured `tool_choice` into the map shape the Anthropic
# Messages API expects, or returns `nil` when nothing valid is configured so
# the key is omitted from the request body entirely.
#
# Supported shapes:
#   %{"type" => "tool", "name" => name} - force a specific named tool
#   %{"type" => type}                   - e.g. "auto" or "any"
defp get_tool_choice(%ChatAnthropic{tool_choice: choice}) do
  case choice do
    %{"type" => "tool", "name" => name} when is_binary(name) and byte_size(name) > 0 ->
      %{"type" => "tool", "name" => name}

    %{"type" => type} when is_binary(type) and byte_size(type) > 0 ->
      %{"type" => type}

    _ ->
      nil
  end
end


defp get_tools_for_api(nil), do: []

defp get_tools_for_api(tools) do
Expand Down
25 changes: 22 additions & 3 deletions lib/chat_models/chat_open_ai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ defmodule LangChain.ChatModels.ChatOpenAI do
# streaming.
field :stream_options, :map, default: nil

# Tool choice option
field :tool_choice, :map

# A list of maps for callback handlers
field :callbacks, {:array, :map}, default: []

Expand All @@ -157,7 +160,8 @@ defmodule LangChain.ChatModels.ChatOpenAI do
:max_tokens,
:stream_options,
:user,
:callbacks
:callbacks,
:tool_choice
]
@required_fields [:endpoint, :model]

Expand Down Expand Up @@ -242,6 +246,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
get_stream_options_for_api(openai.stream_options)
)
|> Utils.conditionally_add_to_map(:tools, get_tools_for_api(tools))
|> Utils.conditionally_add_to_map(:tool_choice, get_tool_choice(openai))
end

defp get_tools_for_api(nil), do: []
Expand All @@ -265,6 +270,18 @@ defmodule LangChain.ChatModels.ChatOpenAI do
defp set_response_format(%ChatOpenAI{json_response: false}),
do: %{"type" => "text"}

# Normalizes the configured `tool_choice` for the OpenAI chat completions API.
#
# A specific function is sent as a nested map; a bare type (e.g. "auto",
# "none") is sent as a plain string, which is the format OpenAI expects.
# Returns `nil` when unset or unrecognized so the key is omitted from the
# request body.
defp get_tool_choice(%ChatOpenAI{tool_choice: choice}) do
  case choice do
    %{"type" => "function", "function" => %{"name" => name}}
    when is_binary(name) and byte_size(name) > 0 ->
      %{"type" => "function", "function" => %{"name" => name}}

    %{"type" => type} when is_binary(type) and byte_size(type) > 0 ->
      type

    _ ->
      nil
  end
end

@doc """
Convert a LangChain structure to the expected map of data for the OpenAI API.
"""
Expand Down Expand Up @@ -692,8 +709,10 @@ defmodule LangChain.ChatModels.ChatOpenAI do
# Full message with tool call
def do_process_response(
model,
%{"finish_reason" => "tool_calls", "message" => %{"tool_calls" => calls} = message} = data
) do
%{"finish_reason" => finish_reason, "message" => %{"tool_calls" => calls} = message} =
data
)
when finish_reason in ["tool_calls", "stop"] do
case Message.new(%{
"role" => "assistant",
"content" => message["content"],
Expand Down
55 changes: 55 additions & 0 deletions test/chat_models/chat_anthropic_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,30 @@ defmodule LangChain.ChatModels.ChatAnthropicTest do
assert data.max_tokens == 1234
end

test "generated a map for an API call with tool_choice set correctly to auto" do
  choice = %{"type" => "auto"}

  {:ok, anthropic} = ChatAnthropic.new(%{model: @test_model, tool_choice: choice})

  data = ChatAnthropic.for_api(anthropic, [], [])

  # Anthropic receives the tool_choice as a map, unchanged.
  assert data.model == @test_model
  assert data.tool_choice == choice
end

test "generated a map for an API call with tool_choice set correctly to a specific function" do
  choice = %{"type" => "tool", "name" => "get_weather"}

  {:ok, anthropic} = ChatAnthropic.new(%{model: @test_model, tool_choice: choice})

  data = ChatAnthropic.for_api(anthropic, [], [])

  # Forcing a named tool keeps both the "type" and "name" keys.
  assert data.model == @test_model
  assert data.tool_choice == choice
end

test "adds tool definitions to map" do
tool =
Function.new!(%{
Expand Down Expand Up @@ -1152,6 +1176,37 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text
# ],
# }
end

@tag live_call: true, live_anthropic: true
test "executes a call with tool_choice set as a specific name" do
  # tool_choice forces the model to call "do_another_thing" even though the
  # prompt deliberately offers a choice between two tools.
  {:ok, chat} =
    ChatAnthropic.new(%{
      model: @test_model,
      tool_choice: %{"type" => "tool", "name" => "do_another_thing"}
    })

  user_message =
    Message.new_user!(
      "Use the 'do_something' tool with the value 'cat', or use 'do_another_thing' tool with the name 'foo'"
    )

  do_something_tool =
    Function.new!(%{
      name: "do_something",
      parameters: [FunctionParam.new!(%{type: :string, name: "value", required: true})],
      function: fn _args, _context -> :ok end
    })

  do_another_thing_tool =
    Function.new!(%{
      name: "do_another_thing",
      parameters: [FunctionParam.new!(%{type: :string, name: "name", required: true})],
      function: fn _args, _context -> :ok end
    })

  {:ok, response} =
    ChatAnthropic.call(chat, [user_message], [do_something_tool, do_another_thing_tool])

  # The forced tool — and only that tool — must have been called.
  assert %Message{role: :assistant} = response
  assert [%ToolCall{} = call] = response.tool_calls
  assert call.status == :complete
  assert call.type == :function
  assert call.name == "do_another_thing"
  assert call.arguments == %{"name" => "foo"}
end
end

describe "works within a chain" do
Expand Down
72 changes: 72 additions & 0 deletions test/chat_models/chat_open_ai_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,30 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
assert data.model == @test_model
assert data.stream_options == %{"include_usage" => true}
end

test "generated a map for an API call with tool_choice set correctly to auto" do
  {:ok, openai} = ChatOpenAI.new(%{model: @test_model, tool_choice: %{"type" => "auto"}})

  data = ChatOpenAI.for_api(openai, [], [])

  assert data.model == @test_model
  # Bare choice types are flattened to a plain string for the OpenAI API.
  assert data.tool_choice == "auto"
end

test "generated a map for an API call with tool_choice set correctly to a specific function" do
  choice = %{"type" => "function", "function" => %{"name" => "set_weather"}}

  {:ok, openai} = ChatOpenAI.new(%{model: @test_model, tool_choice: choice})

  data = ChatOpenAI.for_api(openai, [], [])

  # A specific function choice is kept as the nested map form.
  assert data.model == @test_model
  assert data.tool_choice == choice
end
end

describe "for_api/1" do
Expand Down Expand Up @@ -689,6 +713,54 @@ defmodule LangChain.ChatModels.ChatOpenAITest do
assert call.arguments == %{"city" => "Moab", "state" => "UT"}
end

@tag live_call: true, live_open_ai: true
test "executing a call with tool_choice set as none", %{
  weather: weather,
  hello_world: hello_world
} do
  # With tool_choice "none", the model must answer in plain text and may
  # not invoke any of the provided tools.
  {:ok, chat} =
    ChatOpenAI.new(%{seed: 0, stream: false, model: @gpt4, tool_choice: %{"type" => "none"}})

  {:ok, user_message} = Message.new_user("What is the weather like in Moab Utah?")

  {:ok, [response]} = ChatOpenAI.call(chat, [user_message], [weather, hello_world])

  assert %Message{role: :assistant, status: :complete} = response
  assert response.content != nil
  assert response.tool_calls == []
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding another live call test that forces the use of a tool triggers the bug I mentioned:

    @tag live_call: true, live_open_ai: true
    test "executing a call with required tool_choice", %{
      weather: weather,
      hello_world: hello_world
    } do
      {:ok, chat} =
        ChatOpenAI.new(%{
          seed: 0,
          stream: false,
          model: @gpt4,
          tool_choice: %{"type" => "function", "function" => %{"name" => "get_weather"}}
        })

      {:ok, message} =
        Message.new_user("What is the weather like in Moab Utah?")

      {:ok, [message]} = ChatOpenAI.call(chat, [message], [weather, hello_world])

      # assertions...
    end
  1) test call/2 executing a call with required tool_choice (LangChain.ChatModels.ChatOpenAITest)
     test/chat_models/chat_open_ai_test.exs:737
     ** (WithClauseError) no with clause matching: %{"function" => %{"arguments" => "{\"city\":\"Moab\",\"state\":\"UT\"}", "name" => "get_weather"}, "id" => "call_q5WG1gd5X8gCqf9UlDceB3WA", "type" => "function"}
     code: {:ok, [message]} = ChatOpenAI.call(chat, [message], [weather, hello_world])
     stacktrace:
       (langchain 0.3.0-rc.0) lib/message.ex:225: anonymous fn/1 in LangChain.Message.validate_and_parse_tool_calls/1
       (elixir 1.16.2) lib/enum.ex:1700: Enum."-map/2-lists^map/1-1-"/2
       (langchain 0.3.0-rc.0) lib/message.ex:224: LangChain.Message.validate_and_parse_tool_calls/1
       (langchain 0.3.0-rc.0) lib/message.ex:151: LangChain.Message.common_validations/1
       (langchain 0.3.0-rc.0) lib/message.ex:121: LangChain.Message.new/1
       (langchain 0.3.0-rc.0) lib/chat_models/chat_open_ai.ex:825: LangChain.ChatModels.ChatOpenAI.do_process_response/2
       (elixir 1.16.2) lib/enum.ex:1700: Enum."-map/2-lists^map/1-1-"/2
       (langchain 0.3.0-rc.0) lib/chat_models/chat_open_ai.ex:552: LangChain.ChatModels.ChatOpenAI.do_api_request/4
       (langchain 0.3.0-rc.0) lib/chat_models/chat_open_ai.ex:473: LangChain.ChatModels.ChatOpenAI.call/3
       test/chat_models/chat_open_ai_test.exs:752: (test)

These changes fix that issue: https://github.com/brainlid/langchain/pull/141/files#diff-98396b383d660bb5274db90a5fd13b1526d5d6b120a2d8ece0e8baf06e55b85eL695-R713

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit includes the fix as you described, as well as the live call test that produced the bug previously.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @stevehodgkiss for the help reviewing!

@tag live_call: true, live_open_ai: true
test "executing a call with required tool_choice", %{
  weather: weather,
  hello_world: hello_world
} do
  # Forces the "get_weather" function even though two tools are offered.
  {:ok, chat} =
    ChatOpenAI.new(%{
      seed: 0,
      stream: false,
      model: @gpt4,
      tool_choice: %{"type" => "function", "function" => %{"name" => "get_weather"}}
    })

  {:ok, user_message} = Message.new_user("What is the weather like in Moab Utah?")

  {:ok, [response]} = ChatOpenAI.call(chat, [user_message], [weather, hello_world])

  # Exactly one complete call to the forced function is expected back.
  assert %Message{role: :assistant, status: :complete} = response
  assert [%LangChain.Message.ToolCall{} = tool_call] = response.tool_calls
  assert tool_call.name == "get_weather"
  assert tool_call.type == :function
  assert tool_call.status == :complete
  assert is_map(tool_call.arguments)
end

@tag live_call: true, live_open_ai: true
test "LIVE: supports receiving multiple tool calls in a single response", %{weather: weather} do
{:ok, chat} =
Expand Down