Skip to content

Commit

Permalink
Advancing Tool Support - Part 4
Browse files Browse the repository at this point in the history
* Adopted new tool calling logic in OllamaChatModel, while maintaining full API backward compatibility thanks to the LegacyToolCallingManager.
* Improved efficiency and robustness of merging options in prompts for Ollama.
* Update Ollama Autoconfiguration to use the new ToolCallingManager.
* Improved troubleshooting for new tool calling APIs and finalised changes for full backward compatibility.
* Updated Ollama Testcontainers dependency to 0.5.7.

Relates to spring-projectsgh-2049

Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
  • Loading branch information
ThomasVitale authored and tzolov committed Jan 30, 2025
1 parent 76ab91f commit b902ca2
Show file tree
Hide file tree
Showing 31 changed files with 1,269 additions and 151 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
package org.springframework.ai.ollama.api;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -32,7 +33,8 @@
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand All @@ -48,7 +50,7 @@
* @see <a href="https://github.com/ollama/ollama/blob/main/api/types.go">Ollama Types</a>
*/
@JsonInclude(Include.NON_NULL)
public class OllamaOptions implements FunctionCallingOptions, EmbeddingOptions {
public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions {

private static final List<String> NON_SUPPORTED_FIELDS = List.of("model", "format", "keep_alive", "truncate");

Expand Down Expand Up @@ -305,28 +307,28 @@ public class OllamaOptions implements FunctionCallingOptions, EmbeddingOptions {
@JsonProperty("truncate")
private Boolean truncate;

@JsonIgnore
private Boolean internalToolExecutionEnabled;

/**
* Tool Function Callbacks to register with the ChatModel.
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
* For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
* from the registry to be used by the ChatModel chat completion requests.
*/
@JsonIgnore
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
private List<FunctionCallback> toolCallbacks = new ArrayList<>();

/**
* List of functions, identified by their names, to configure for function calling in
* the chat completion requests.
* Functions with those names must exist in the functionCallbacks registry.
* The {@link #functionCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution.
* The {@link #toolCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution.
* Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing.
* If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution.
*/
@JsonIgnore
private Set<String> functions = new HashSet<>();

@JsonIgnore
private Boolean proxyToolCalls;
private Set<String> toolNames = new HashSet<>();

@JsonIgnore
private Map<String, Object> toolContext;
Expand Down Expand Up @@ -381,9 +383,9 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
.mirostatEta(fromOptions.getMirostatEta())
.penalizeNewline(fromOptions.getPenalizeNewline())
.stop(fromOptions.getStop())
.functions(fromOptions.getFunctions())
.proxyToolCalls(fromOptions.getProxyToolCalls())
.functionCallbacks(fromOptions.getFunctionCallbacks())
.tools(fromOptions.getTools())
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
.toolCallbacks(fromOptions.getToolCallbacks())
.toolContext(fromOptions.getToolContext()).build();
}

Expand Down Expand Up @@ -683,23 +685,73 @@ public void setTruncate(Boolean truncate) {
}

@Override
@JsonIgnore
public List<FunctionCallback> getToolCallbacks() {
return this.toolCallbacks;
}

@Override
@JsonIgnore
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.toolCallbacks = toolCallbacks;
}

@Override
@JsonIgnore
public Set<String> getTools() {
return this.toolNames;
}

@Override
@JsonIgnore
public void setTools(Set<String> toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
this.toolNames = toolNames;
}

@Override
@Nullable
@JsonIgnore
public Boolean isInternalToolExecutionEnabled() {
return internalToolExecutionEnabled;
}

@Override
@JsonIgnore
public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

@Override
@Deprecated
@JsonIgnore
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
return this.getToolCallbacks();
}

@Override
@Deprecated
@JsonIgnore
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
this.functionCallbacks = functionCallbacks;
this.setToolCallbacks(functionCallbacks);
}

@Override
@Deprecated
@JsonIgnore
public Set<String> getFunctions() {
return this.functions;
return this.getTools();
}

@Override
@Deprecated
@JsonIgnore
public void setFunctions(Set<String> functions) {
this.functions = functions;
this.setTools(functions);
}

@Override
Expand All @@ -709,20 +761,26 @@ public Integer getDimensions() {
}

@Override
@Deprecated
@JsonIgnore
public Boolean getProxyToolCalls() {
return this.proxyToolCalls;
return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null;
}

@Deprecated
@JsonIgnore
public void setProxyToolCalls(Boolean proxyToolCalls) {
this.proxyToolCalls = proxyToolCalls;
this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null;
}

@Override
@JsonIgnore
public Map<String, Object> getToolContext() {
return this.toolContext;
}

@Override
@JsonIgnore
public void setToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
}
Expand Down Expand Up @@ -769,9 +827,9 @@ public boolean equals(Object o) {
&& Objects.equals(this.mirostat, that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau)
&& Objects.equals(this.mirostatEta, that.mirostatEta)
&& Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop)
&& Objects.equals(this.functionCallbacks, that.functionCallbacks)
&& Objects.equals(this.proxyToolCalls, that.proxyToolCalls)
&& Objects.equals(this.functions, that.functions) && Objects.equals(this.toolContext, that.toolContext);
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
&& Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.toolContext, that.toolContext);
}

@Override
Expand All @@ -781,7 +839,7 @@ public int hashCode() {
this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK,
this.topP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls,
this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
this.toolContext);
}

Expand Down Expand Up @@ -959,25 +1017,53 @@ public Builder stop(List<String> stop) {
return this;
}

public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
this.options.setToolCallbacks(toolCallbacks);
return this;
}

public Builder functions(Set<String> functions) {
Assert.notNull(functions, "Function names must not be null");
this.options.functions = functions;
public Builder toolCallbacks(FunctionCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks));
return this;
}

public Builder function(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
public Builder tools(Set<String> toolNames) {
this.options.setTools(toolNames);
return this;
}

public Builder tools(String... toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.toolNames.addAll(Set.of(toolNames));
return this;
}

public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
return this;
}

@Deprecated
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
return toolCallbacks(functionCallbacks);
}

@Deprecated
public Builder functions(Set<String> functions) {
return tools(functions);
}

@Deprecated
public Builder function(String functionName) {
return tools(functionName);
}

@Deprecated
public Builder proxyToolCalls(Boolean proxyToolCalls) {
this.options.proxyToolCalls = proxyToolCalls;
if (proxyToolCalls != null) {
this.options.setInternalToolExecutionEnabled(!proxyToolCalls);
}
return this;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,17 +29,18 @@
* @author Christian Tzolov
* @author Thomas Vitale
*/
public class OllamaChatRequestTests {
class OllamaChatRequestTests {

OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(new OllamaApi())
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
.build();

@Test
public void createRequestWithDefaultOptions() {
void createRequestWithDefaultOptions() {
var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content"));

var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content"), false);
var request = this.chatModel.ollamaChatRequest(prompt, false);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isFalse();
Expand All @@ -52,12 +53,12 @@ public void createRequestWithDefaultOptions() {
}

@Test
public void createRequestWithPromptOllamaOptions() {

void createRequestWithPromptOllamaOptions() {
// Runtime options should override the default options.
OllamaOptions promptOptions = OllamaOptions.builder().temperature(0.8).topP(0.5).numGPU(2).build();
var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions));

var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
var request = this.chatModel.ollamaChatRequest(prompt, true);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isTrue();
Expand All @@ -74,11 +75,11 @@ public void createRequestWithPromptOllamaOptions() {

@Test
public void createRequestWithPromptPortableChatOptions() {

// Ollama runtime options.
ChatOptions portablePromptOptions = ChatOptions.builder().temperature(0.9).topK(100).topP(0.6).build();
var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", portablePromptOptions));

var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true);
var request = this.chatModel.ollamaChatRequest(prompt, true);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isTrue();
Expand All @@ -92,31 +93,33 @@ public void createRequestWithPromptPortableChatOptions() {

@Test
public void createRequestWithPromptOptionsModelOverride() {

// Ollama runtime options.
OllamaOptions promptOptions = OllamaOptions.builder().model("PROMPT_MODEL").build();
var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions));

var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
var request = this.chatModel.ollamaChatRequest(prompt, true);

assertThat(request.model()).isEqualTo("PROMPT_MODEL");
}

@Test
public void createRequestWithDefaultOptionsModelOverride() {

OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(new OllamaApi())
.defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build())
.build();

var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), true);
var prompt1 = chatModel.buildRequestPrompt(new Prompt("Test message content"));

var request = chatModel.ollamaChatRequest(prompt1, true);

assertThat(request.model()).isEqualTo("DEFAULT_OPTIONS_MODEL");

// Prompt options should override the default options.
OllamaOptions promptOptions = OllamaOptions.builder().model("PROMPT_MODEL").build();
var prompt2 = chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions));

request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
request = chatModel.ollamaChatRequest(prompt2, true);

assertThat(request.model()).isEqualTo("PROMPT_MODEL");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class OllamaWithOpenAiChatModelIT {
private static final String DEFAULT_OLLAMA_MODEL = "mistral";

@Container
static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.5.1");
static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.5.7");

static String baseUrl = "http://localhost:11434";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,7 @@ interface Builder {
Builder defaultFunctions(String... functionNames);

/**
* @deprecated in favor of {@link #defaultTools(FunctionCallback...)} or
* {@link #defaultToolCallbacks(FunctionCallback...)}
* @deprecated in favor of {@link #defaultTools(Object...)}
*/
@Deprecated
Builder defaultFunctions(FunctionCallback... functionCallbacks);
Expand Down
Loading

0 comments on commit b902ca2

Please sign in to comment.