diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 3329462bee..12eef97525 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,15 +18,21 @@ import java.time.Duration; import java.util.Base64; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; +import com.fasterxml.jackson.core.type.TypeReference; import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.LegacyToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.util.json.JsonParser; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; @@ -82,6 +88,8 @@ */ public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel { + private static final Logger logger = LoggerFactory.getLogger(OllamaChatModel.class); + private static final String DONE = "done"; private static final String METADATA_PROMPT_EVAL_COUNT = "prompt-eval-count"; @@ -100,6 +108,8 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); + private final OllamaApi chatApi; private final OllamaOptions defaultOptions; @@ -108,8 +118,11 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode private final OllamaModelManager modelManager; + private final ToolCallingManager toolCallingManager; + private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + @Deprecated public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { @@ -120,6 +133,26 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null"); this.chatApi = ollamaApi; this.defaultOptions = defaultOptions; + this.toolCallingManager = new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks); + this.observationRegistry = observationRegistry; + this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions); + initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy()); + + logger.warn("This constructor is deprecated and will be removed in the next milestone. " + + "Please use the new constructor accepting ToolCallingManager instead."); + } + + public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, + ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { + super(null, defaultOptions, List.of()); + Assert.notNull(ollamaApi, "ollamaApi must not be null"); + Assert.notNull(defaultOptions, "defaultOptions must not be null"); + Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); + Assert.notNull(observationRegistry, "observationRegistry must not be null"); + Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null"); + this.chatApi = ollamaApi; + this.defaultOptions = defaultOptions; + this.toolCallingManager = toolCallingManager; this.observationRegistry = observationRegistry; this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions); initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy()); @@ -186,7 +219,10 @@ private static DefaultUsage getDefaultUsage(OllamaApi.ChatResponse response) { @Override public ChatResponse call(Prompt prompt) { - return this.internalCall(prompt, null); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalCall(requestPrompt, null); } private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { @@ -196,7 +232,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(OllamaApi.PROVIDER_NAME) - .requestOptions(buildRequestOptions(request)) + .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -233,9 +269,9 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon }); - if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null - && isToolCall(response, Set.of("stop"))) { - var toolCallConversation = handleToolCalls(prompt, response); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response != null + && response.hasToolCalls()) { + var toolCallConversation = this.toolCallingManager.executeToolCalls(prompt, response); // Recursively call the call method with the tool call message // conversation that contains the call responses. return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response); @@ -246,7 +282,10 @@ && isToolCall(response, Set.of("stop"))) { @Override public Flux stream(Prompt prompt) { - return this.internalStream(prompt, null); + // Before moving any further, build the final request Prompt, + // merging runtime and default options. + Prompt requestPrompt = buildRequestPrompt(prompt); + return this.internalStream(requestPrompt, null); } private Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { @@ -256,7 +295,7 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(OllamaApi.PROVIDER_NAME) - .requestOptions(buildRequestOptions(request)) + .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( @@ -295,8 +334,8 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh // @formatter:off Flux chatResponseFlux = chatResponse.flatMap(response -> { - if (isToolCall(response, Set.of("stop"))) { - var toolCallConversation = handleToolCalls(prompt, response); + if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && response.hasToolCalls()) { + var toolCallConversation = this.toolCallingManager.executeToolCalls(prompt, response); // Recursively call the stream method with the tool call message // conversation that contains the call responses. return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response); @@ -316,6 +355,48 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh }); } + Prompt buildRequestPrompt(Prompt prompt) { + // Process runtime options + OllamaOptions runtimeOptions = null; + if (prompt.getOptions() != null) { + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, + OllamaOptions.class); + } + else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { + runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, + OllamaOptions.class); + } + else { + runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + OllamaOptions.class); + } + } + + // Define request options by merging runtime options and default options + OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + OllamaOptions.class); + // Merge tool names and tool callbacks explicitly since they are ignored by + // Jackson, used by ModelOptionsUtils. + if (runtimeOptions != null) { + requestOptions.setTools( + ToolCallingChatOptions.mergeToolNames(runtimeOptions.getTools(), this.defaultOptions.getTools())); + requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), + this.defaultOptions.getToolCallbacks())); + } + else { + requestOptions.setTools(this.defaultOptions.getTools()); + requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + } + + // Validate request options + if (!StringUtils.hasText(requestOptions.getModel())) { + throw new IllegalArgumentException("model cannot be null or empty"); + } + + return new Prompt(prompt.getInstructions(), requestOptions); + } + /** * Package access for testing. */ @@ -338,7 +419,8 @@ else if (message instanceof AssistantMessage assistantMessage) { if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) { toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> { var function = new ToolCallFunction(toolCall.name(), - ModelOptionsUtils.jsonToMap(toolCall.arguments())); + JsonParser.fromJson(toolCall.arguments(), new TypeReference<>() { + })); return new ToolCall(function); }).toList(); } @@ -356,49 +438,24 @@ else if (message instanceof ToolResponseMessage toolMessage) { throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType()); }).flatMap(List::stream).toList(); - Set functionsForThisRequest = new HashSet<>(); - - // runtime options - OllamaOptions runtimeOptions = null; - if (prompt.getOptions() != null) { - if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, - OllamaOptions.class); - } - else { - runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, - OllamaOptions.class); - } - functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(runtimeOptions)); - } + OllamaOptions requestOptions = (OllamaOptions) prompt.getOptions(); - if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) { - functionsForThisRequest.addAll(this.defaultOptions.getFunctions()); - } - OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class); - - // Override the model. - if (!StringUtils.hasText(mergedOptions.getModel())) { - throw new IllegalArgumentException("Model is not set!"); - } - - String model = mergedOptions.getModel(); - OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(model) + OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(requestOptions.getModel()) .stream(stream) .messages(ollamaMessages) - .options(mergedOptions); + .options(requestOptions); - if (mergedOptions.getFormat() != null) { - requestBuilder.format(mergedOptions.getFormat()); + if (requestOptions.getFormat() != null) { + requestBuilder.format(requestOptions.getFormat()); } - if (mergedOptions.getKeepAlive() != null) { - requestBuilder.keepAlive(mergedOptions.getKeepAlive()); + if (requestOptions.getKeepAlive() != null) { + requestBuilder.keepAlive(requestOptions.getKeepAlive()); } - // Add the enabled functions definitions to the request's tools parameter. - if (!CollectionUtils.isEmpty(functionsForThisRequest)) { - requestBuilder.tools(this.getFunctionTools(functionsForThisRequest)); + List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); + if (!CollectionUtils.isEmpty(toolDefinitions)) { + requestBuilder.tools(this.getTools(toolDefinitions)); } return requestBuilder.build(); @@ -417,28 +474,14 @@ else if (mediaData instanceof String text) { } - private List getFunctionTools(Set functionNames) { - return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { - var function = new ChatRequest.Tool.Function(functionCallback.getName(), functionCallback.getDescription(), - functionCallback.getInputTypeSchema()); - return new ChatRequest.Tool(function); + private List getTools(List toolDefinitions) { + return toolDefinitions.stream().map(toolDefinition -> { + var tool = new ChatRequest.Tool.Function(toolDefinition.name(), toolDefinition.description(), + toolDefinition.inputSchema()); + return new ChatRequest.Tool(tool); }).toList(); } - private ChatOptions buildRequestOptions(OllamaApi.ChatRequest request) { - var options = ModelOptionsUtils.mapToClass(request.options(), OllamaOptions.class); - return ChatOptions.builder() - .model(request.model()) - .frequencyPenalty(options.getFrequencyPenalty()) - .maxTokens(options.getMaxTokens()) - .presencePenalty(options.getPresencePenalty()) - .stopSequences(options.getStopSequences()) - .temperature(options.getTemperature()) - .topK(options.getTopK()) - .topP(options.getTopP()) - .build(); - } - @Override public ChatOptions getDefaultOptions() { return OllamaOptions.fromOptions(this.defaultOptions); @@ -468,9 +511,11 @@ public static final class Builder { private OllamaOptions defaultOptions = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build(); + private ToolCallingManager toolCallingManager; + private FunctionCallbackResolver functionCallbackResolver; - private List toolFunctionCallbacks = List.of(); + private List toolFunctionCallbacks; private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; @@ -489,11 +534,18 @@ public Builder defaultOptions(OllamaOptions defaultOptions) { return this; } + public Builder toolCallingManager(ToolCallingManager toolCallingManager) { + this.toolCallingManager = toolCallingManager; + return this; + } + + @Deprecated public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { this.functionCallbackResolver = functionCallbackResolver; return this; } + @Deprecated public Builder toolFunctionCallbacks(List toolFunctionCallbacks) { this.toolFunctionCallbacks = toolFunctionCallbacks; return this; @@ -510,8 +562,27 @@ public Builder modelManagementOptions(ModelManagementOptions modelManagementOpti } public OllamaChatModel build() { - return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver, - this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions); + if (toolCallingManager != null) { + Assert.isNull(functionCallbackResolver, + "functionCallbackResolver must not be set when toolCallingManager is set"); + Assert.isNull(toolFunctionCallbacks, + "toolFunctionCallbacks must not be set when toolCallingManager is set"); + + return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager, + this.observationRegistry, this.modelManagementOptions); + } + + if (functionCallbackResolver != null) { + Assert.isNull(toolCallingManager, + "toolCallingManager must not be set when functionCallbackResolver is set"); + List toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks + : List.of(); + return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver, + toolCallbacks, this.observationRegistry, this.modelManagementOptions); + } + + return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, + this.observationRegistry, this.modelManagementOptions); } } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index dcc8cc50be..af36cd3941 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.ai.ollama.api; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -32,7 +33,8 @@ import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -48,7 +50,7 @@ * @see Ollama Types */ @JsonInclude(Include.NON_NULL) -public class OllamaOptions implements FunctionCallingOptions, EmbeddingOptions { +public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { private static final List NON_SUPPORTED_FIELDS = List.of("model", "format", "keep_alive", "truncate"); @@ -305,6 +307,9 @@ public class OllamaOptions implements FunctionCallingOptions, EmbeddingOptions { @JsonProperty("truncate") private Boolean truncate; + @JsonIgnore + private Boolean internalToolExecutionEnabled; + /** * Tool Function Callbacks to register with the ChatModel. * For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution. @@ -312,21 +317,18 @@ public class OllamaOptions implements FunctionCallingOptions, EmbeddingOptions { * from the registry to be used by the ChatModel chat completion requests. */ @JsonIgnore - private List functionCallbacks = new ArrayList<>(); + private List toolCallbacks = new ArrayList<>(); /** * List of functions, identified by their names, to configure for function calling in * the chat completion requests. * Functions with those names must exist in the functionCallbacks registry. - * The {@link #functionCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution. + * The {@link #toolCallbacks} from the PromptOptions are automatically enabled for the duration of the prompt execution. * Note that function enabled with the default options are enabled for all chat completion requests. This could impact the token count and the billing. * If the functions is set in a prompt options, then the enabled functions are only active for the duration of this prompt execution. */ @JsonIgnore - private Set functions = new HashSet<>(); - - @JsonIgnore - private Boolean proxyToolCalls; + private Set toolNames = new HashSet<>(); @JsonIgnore private Map toolContext; @@ -381,9 +383,9 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .mirostatEta(fromOptions.getMirostatEta()) .penalizeNewline(fromOptions.getPenalizeNewline()) .stop(fromOptions.getStop()) - .functions(fromOptions.getFunctions()) - .proxyToolCalls(fromOptions.getProxyToolCalls()) - .functionCallbacks(fromOptions.getFunctionCallbacks()) + .tools(fromOptions.getTools()) + .internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()) + .toolCallbacks(fromOptions.getToolCallbacks()) .toolContext(fromOptions.getToolContext()).build(); } @@ -683,23 +685,73 @@ public void setTruncate(Boolean truncate) { } @Override + @JsonIgnore + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + @JsonIgnore + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = toolCallbacks; + } + + @Override + @JsonIgnore + public Set getTools() { + return this.toolNames; + } + + @Override + @JsonIgnore + public void setTools(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements")); + this.toolNames = toolNames; + } + + @Override + @Nullable + @JsonIgnore + public Boolean isInternalToolExecutionEnabled() { + return internalToolExecutionEnabled; + } + + @Override + @JsonIgnore + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } + + @Override + @Deprecated + @JsonIgnore public List getFunctionCallbacks() { - return this.functionCallbacks; + return this.getToolCallbacks(); } @Override + @Deprecated + @JsonIgnore public void setFunctionCallbacks(List functionCallbacks) { - this.functionCallbacks = functionCallbacks; + this.setToolCallbacks(functionCallbacks); } @Override + @Deprecated + @JsonIgnore public Set getFunctions() { - return this.functions; + return this.getTools(); } @Override + @Deprecated + @JsonIgnore public void setFunctions(Set functions) { - this.functions = functions; + this.setTools(functions); } @Override @@ -709,20 +761,26 @@ public Integer getDimensions() { } @Override + @Deprecated + @JsonIgnore public Boolean getProxyToolCalls() { - return this.proxyToolCalls; + return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null; } + @Deprecated + @JsonIgnore public void setProxyToolCalls(Boolean proxyToolCalls) { - this.proxyToolCalls = proxyToolCalls; + this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null; } @Override + @JsonIgnore public Map getToolContext() { return this.toolContext; } @Override + @JsonIgnore public void setToolContext(Map toolContext) { this.toolContext = toolContext; } @@ -769,9 +827,9 @@ public boolean equals(Object o) { && Objects.equals(this.mirostat, that.mirostat) && Objects.equals(this.mirostatTau, that.mirostatTau) && Objects.equals(this.mirostatEta, that.mirostatEta) && Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop) - && Objects.equals(this.functionCallbacks, that.functionCallbacks) - && Objects.equals(this.proxyToolCalls, that.proxyToolCalls) - && Objects.equals(this.functions, that.functions) && Objects.equals(this.toolContext, that.toolContext); + && Objects.equals(this.toolCallbacks, that.toolCallbacks) + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.toolContext, that.toolContext); } @Override @@ -781,7 +839,7 @@ public int hashCode() { this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, this.topP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, - this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls, + this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.toolContext); } @@ -959,25 +1017,53 @@ public Builder stop(List stop) { return this; } - public Builder functionCallbacks(List functionCallbacks) { - this.options.functionCallbacks = functionCallbacks; + public Builder toolCallbacks(List toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); return this; } - public Builder functions(Set functions) { - Assert.notNull(functions, "Function names must not be null"); - this.options.functions = functions; + public Builder toolCallbacks(FunctionCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); return this; } - public Builder function(String functionName) { - Assert.hasText(functionName, "Function name must not be empty"); - this.options.functions.add(functionName); + public Builder tools(Set toolNames) { + this.options.setTools(toolNames); + return this; + } + + public Builder tools(String... toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.toolNames.addAll(Set.of(toolNames)); + return this; + } + + public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); return this; } + @Deprecated + public Builder functionCallbacks(List functionCallbacks) { + return toolCallbacks(functionCallbacks); + } + + @Deprecated + public Builder functions(Set functions) { + return tools(functions); + } + + @Deprecated + public Builder function(String functionName) { + return tools(functionName); + } + + @Deprecated public Builder proxyToolCalls(Boolean proxyToolCalls) { - this.options.proxyToolCalls = proxyToolCalls; + if (proxyToolCalls != null) { + this.options.setInternalToolExecutionEnabled(!proxyToolCalls); + } return this; } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index 63e2c7933c..9f38c6fa06 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ * @author Christian Tzolov * @author Thomas Vitale */ -public class OllamaChatRequestTests { +class OllamaChatRequestTests { OllamaChatModel chatModel = OllamaChatModel.builder() .ollamaApi(new OllamaApi()) @@ -37,9 +37,10 @@ public class OllamaChatRequestTests { .build(); @Test - public void createRequestWithDefaultOptions() { + void createRequestWithDefaultOptions() { + var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content")); - var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content"), false); + var request = this.chatModel.ollamaChatRequest(prompt, false); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isFalse(); @@ -52,12 +53,12 @@ public void createRequestWithDefaultOptions() { } @Test - public void createRequestWithPromptOllamaOptions() { - + void createRequestWithPromptOllamaOptions() { // Runtime options should override the default options. OllamaOptions promptOptions = OllamaOptions.builder().temperature(0.8).topP(0.5).numGPU(2).build(); + var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions)); - var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); + var request = this.chatModel.ollamaChatRequest(prompt, true); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isTrue(); @@ -74,11 +75,11 @@ public void createRequestWithPromptOllamaOptions() { @Test public void createRequestWithPromptPortableChatOptions() { - // Ollama runtime options. ChatOptions portablePromptOptions = ChatOptions.builder().temperature(0.9).topK(100).topP(0.6).build(); + var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", portablePromptOptions)); - var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true); + var request = this.chatModel.ollamaChatRequest(prompt, true); assertThat(request.messages()).hasSize(1); assertThat(request.stream()).isTrue(); @@ -92,31 +93,33 @@ public void createRequestWithPromptPortableChatOptions() { @Test public void createRequestWithPromptOptionsModelOverride() { - // Ollama runtime options. OllamaOptions promptOptions = OllamaOptions.builder().model("PROMPT_MODEL").build(); + var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions)); - var request = this.chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); + var request = this.chatModel.ollamaChatRequest(prompt, true); assertThat(request.model()).isEqualTo("PROMPT_MODEL"); } @Test public void createRequestWithDefaultOptionsModelOverride() { - OllamaChatModel chatModel = OllamaChatModel.builder() .ollamaApi(new OllamaApi()) .defaultOptions(OllamaOptions.builder().model("DEFAULT_OPTIONS_MODEL").build()) .build(); - var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), true); + var prompt1 = chatModel.buildRequestPrompt(new Prompt("Test message content")); + + var request = chatModel.ollamaChatRequest(prompt1, true); assertThat(request.model()).isEqualTo("DEFAULT_OPTIONS_MODEL"); // Prompt options should override the default options. OllamaOptions promptOptions = OllamaOptions.builder().model("PROMPT_MODEL").build(); + var prompt2 = chatModel.buildRequestPrompt(new Prompt("Test message content", promptOptions)); - request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true); + request = chatModel.ollamaChatRequest(prompt2, true); assertThat(request.model()).isEqualTo("PROMPT_MODEL"); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java index c52bfd50df..c0587c03bd 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java @@ -76,7 +76,7 @@ class OllamaWithOpenAiChatModelIT { private static final String DEFAULT_OLLAMA_MODEL = "mistral"; @Container - static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.5.1"); + static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.5.7"); static String baseUrl = "http://localhost:11434"; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index c90c5ed464..7fd359b10a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -290,8 +290,7 @@ interface Builder { Builder defaultFunctions(String... functionNames); /** - * @deprecated in favor of {@link #defaultTools(FunctionCallback...)} or - * {@link #defaultToolCallbacks(FunctionCallback...)} + * @deprecated in favor of {@link #defaultTools(Object...)} */ @Deprecated Builder defaultFunctions(FunctionCallback... functionCallbacks); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java index 01f694c8aa..090ea4a28a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/AbstractToolCallSupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,10 +64,12 @@ public abstract class AbstractToolCallSupport { */ protected final FunctionCallbackResolver functionCallbackResolver; + @Deprecated protected AbstractToolCallSupport(FunctionCallbackResolver functionCallbackResolver) { this(functionCallbackResolver, FunctionCallingOptions.builder().build(), List.of()); } + @Deprecated protected AbstractToolCallSupport(FunctionCallbackResolver functionCallbackResolver, FunctionCallingOptions functionCallingOptions, List toolFunctionCallbacks) { @@ -97,6 +99,7 @@ private static List merge(FunctionCallingOptions functionOptio return toolFunctionCallbacksCopy; } + @Deprecated public Map getFunctionCallbackRegister() { return this.functionCallbackRegister; } @@ -107,6 +110,7 @@ public Map getFunctionCallbackRegister() { * @param runtimeFunctionOptions FunctionCallingOptions to handle. * @return Set of function names to call. */ + @Deprecated protected Set runtimeFunctionCallbackConfigurations(FunctionCallingOptions runtimeFunctionOptions) { Set enabledFunctionsToCall = new HashSet<>(); @@ -133,6 +137,7 @@ protected Set runtimeFunctionCallbackConfigurations(FunctionCallingOptio return enabledFunctionsToCall; } + @Deprecated protected List handleToolCalls(Prompt prompt, ChatResponse response) { Optional toolCallGeneration = response.getResults() .stream() @@ -165,6 +170,7 @@ protected List handleToolCalls(Prompt prompt, ChatResponse response) { return toolConversationHistory; } + @Deprecated protected List buildToolCallConversation(List previousMessages, AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) { List messages = new ArrayList<>(previousMessages); @@ -179,6 +185,7 @@ protected List buildToolCallConversation(List previousMessages * @param functionNames Name of function callbacks to retrieve. * @return list of resolved FunctionCallbacks. */ + @Deprecated protected List resolveFunctionCallbacks(Set functionNames) { List retrievedFunctionCallbacks = new ArrayList<>(); @@ -208,6 +215,7 @@ protected List resolveFunctionCallbacks(Set functionNa return retrievedFunctionCallbacks; } + @Deprecated protected ToolResponseMessage executeFunctions(AssistantMessage assistantMessage, ToolContext toolContext) { List toolResponses = new ArrayList<>(); @@ -230,6 +238,7 @@ protected ToolResponseMessage executeFunctions(AssistantMessage assistantMessage return new ToolResponseMessage(toolResponses, Map.of()); } + @Deprecated protected boolean isToolCall(ChatResponse chatResponse, Set toolCallFinishReasons) { Assert.isTrue(!CollectionUtils.isEmpty(toolCallFinishReasons), "Tool call finish reasons cannot be empty!"); @@ -252,6 +261,7 @@ protected boolean isToolCall(ChatResponse chatResponse, Set toolCallFini * @param toolCallFinishReasons the tool call finish reasons to check. * @return true if the generation is a tool call, false otherwise. */ + @Deprecated protected boolean isToolCall(Generation generation, Set toolCallFinishReasons) { var finishReason = (generation.getMetadata().getFinishReason() != null) ? generation.getMetadata().getFinishReason() : ""; @@ -271,6 +281,7 @@ protected boolean isToolCall(Generation generation, Set toolCallFinishRe * @param defaultOptions the default tool call options to check. * @return true if the proxyToolCalls is enabled, false otherwise. */ + @Deprecated protected boolean isProxyToolCalls(Prompt prompt, FunctionCallingOptions defaultOptions) { if (prompt.getOptions() instanceof FunctionCallingOptions functionCallOptions && functionCallOptions.getProxyToolCalls() != null) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java index 69b349383c..51df663c70 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/model/ToolContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -71,7 +71,8 @@ public Map getContext() { /** * Returns the tool call history from the context map. - * @return The tool call history. + * @return The tool call history. TODO: review whether we still need this or + * ToolCallingManager solves the original issue */ @SuppressWarnings("unchecked") public List getToolCallHistory() { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index 3140901c05..a5620b5e7e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.model.ModelRequest; +import org.springframework.lang.Nullable; /** * The Prompt class represents a prompt used in AI model requests. A prompt consists of @@ -36,11 +37,13 @@ * * @author Mark Pollack * @author luocongqiu + * @author Thomas Vitale */ public class Prompt implements ModelRequest> { private final List messages; + @Nullable private ChatOptions chatOptions; public Prompt(String contents) { @@ -81,6 +84,7 @@ public String getContents() { } @Override + @Nullable public ChatOptions getOptions() { return this.chatOptions; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java index f350b7c85e..a266f376f9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java @@ -17,6 +17,8 @@ package org.springframework.ai.model.tool; import io.micrometer.observation.ObservationRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -50,6 +52,8 @@ */ public class DefaultToolCallingManager implements ToolCallingManager { + private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallingManager.class); + // @formatter:off private static final ObservationRegistry DEFAULT_OBSERVATION_REGISTRY @@ -86,7 +90,7 @@ public List resolveToolDefinitions(ToolCallingChatOptions chatOp List toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks()); for (String toolName : chatOptions.getTools()) { - ToolCallback toolCallback = toolCallbackResolver.resolve(toolName); + FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName); if (toolCallback == null) { throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); } @@ -176,13 +180,15 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { + logger.debug("Executing tool call: {}", toolCall.name()); + String toolName = toolCall.name(); String toolInputArguments = toolCall.arguments(); FunctionCallback toolCallback = toolCallbacks.stream() .filter(tool -> toolName.equals(tool.getName())) .findFirst() - .orElse(toolCallbackResolver.resolve(toolName)); + .orElseGet(() -> toolCallbackResolver.resolve(toolName)); if (toolCallback == null) { throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java new file mode 100644 index 0000000000..ae12fb05d3 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/LegacyToolCallingManager.java @@ -0,0 +1,241 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.tool; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.model.AbstractToolCallSupport; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.DefaultToolCallExceptionConverter; +import org.springframework.ai.tool.execution.ToolCallExceptionConverter; +import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Implementation of {@link ToolCallingManager} supporting the migration from + * {@link AbstractToolCallSupport} to {@link ToolCallingManager} and ensuring AI + * compatibility for all the ChatModel implementations. + * + * @author Thomas Vitale + * @since 1.0.0 + * @deprecated Only to help moving away from {@link AbstractToolCallSupport}. It will be + * removed in the next milestone. + */ +@Deprecated +public class LegacyToolCallingManager implements ToolCallingManager { + + private final FunctionCallbackResolver functionCallbackResolver; + + private final Map functionCallbacks = new HashMap<>(); + + private final ToolCallExceptionConverter toolCallExceptionConverter = DefaultToolCallExceptionConverter.builder() + .build(); + + public LegacyToolCallingManager(@Nullable FunctionCallbackResolver functionCallbackResolver, + List functionCallbacks) { + Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); + Assert.noNullElements(functionCallbacks.toArray(), "functionCallbacks cannot contain null elements"); + this.functionCallbackResolver = functionCallbackResolver; + functionCallbacks.forEach(toolCallback -> this.functionCallbacks.put(toolCallback.getName(), toolCallback)); + } + + @Override + public List resolveToolDefinitions(ToolCallingChatOptions chatOptions) { + Assert.notNull(chatOptions, "chatOptions cannot be null"); + + List toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks()); + for (String toolName : chatOptions.getTools()) { + FunctionCallback toolCallback = resolveFunctionCallback(toolName); + if (toolCallback == null) { + throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); + } + toolCallbacks.add(toolCallback); + } + + return toolCallbacks.stream().map(functionCallback -> { + if (functionCallback instanceof ToolCallback toolCallback) { + return toolCallback.getToolDefinition(); + } + else { + return ToolDefinition.builder() + .name(functionCallback.getName()) + .description(functionCallback.getDescription()) + .inputSchema(functionCallback.getInputTypeSchema()) + .build(); + } + }).toList(); + } + + @Nullable + private FunctionCallback resolveFunctionCallback(String toolName) { + Assert.hasText(toolName, "toolName cannot be null or empty"); + if (functionCallbacks.get(toolName) != null) { + return functionCallbacks.get(toolName); + } + return functionCallbackResolver != null ? functionCallbackResolver.resolve(toolName) : null; + } + + @Override + public List executeToolCalls(Prompt prompt, ChatResponse chatResponse) { + Assert.notNull(prompt, "prompt cannot be null"); + Assert.notNull(chatResponse, "chatResponse cannot be null"); + + Optional toolCallGeneration = chatResponse.getResults() + .stream() + .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) + .findFirst(); + + if (toolCallGeneration.isEmpty()) { + throw new IllegalStateException("No tool call requested by the chat model"); + } + + AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); + + ToolContext toolContext = buildToolContext(prompt, assistantMessage); + + ToolResponseMessage toolMessageResponse = executeToolCall(prompt, assistantMessage, toolContext); + + return buildConversationHistoryAfterToolExecution(prompt.getInstructions(), assistantMessage, + toolMessageResponse); + } + + private static ToolContext buildToolContext(Prompt prompt, AssistantMessage assistantMessage) { + Map toolContextMap = Map.of(); + + if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions + && !CollectionUtils.isEmpty(functionOptions.getToolContext())) { + toolContextMap = new HashMap<>(functionOptions.getToolContext()); + + List messageHistory = new ArrayList<>(prompt.copy().getInstructions()); + messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), + assistantMessage.getToolCalls())); + + toolContextMap.put(ToolContext.TOOL_CALL_HISTORY, + buildConversationHistoryBeforeToolExecution(prompt, assistantMessage)); + } + + return new ToolContext(toolContextMap); + } + + private static List buildConversationHistoryBeforeToolExecution(Prompt prompt, + AssistantMessage assistantMessage) { + List messageHistory = new ArrayList<>(prompt.copy().getInstructions()); + messageHistory.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), + assistantMessage.getToolCalls())); + return messageHistory; + } + + /** + * Execute the tool call and return the response message. To ensure backward + * compatibility, both {@link ToolCallback} and {@link FunctionCallback} are + * supported. + */ + private ToolResponseMessage executeToolCall(Prompt prompt, AssistantMessage assistantMessage, + ToolContext toolContext) { + List toolCallbacks = List.of(); + if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { + toolCallbacks = toolCallingChatOptions.getToolCallbacks(); + } + else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) { + toolCallbacks = functionOptions.getFunctionCallbacks(); + } + + List toolResponses = new ArrayList<>(); + + for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { + + String toolName = toolCall.name(); + String toolInputArguments = toolCall.arguments(); + + FunctionCallback toolCallback = toolCallbacks.stream() + .filter(tool -> toolName.equals(tool.getName())) + .findFirst() + .orElseGet(() -> resolveFunctionCallback(toolName)); + + if (toolCallback == null) { + throw new IllegalStateException("No ToolCallback found for tool name: " + toolName); + } + + String toolResult; + try { + toolResult = toolCallback.call(toolInputArguments, toolContext); + } + catch (ToolExecutionException ex) { + toolResult = toolCallExceptionConverter.convert(ex); + } + + toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult)); + } + + return new ToolResponseMessage(toolResponses, Map.of()); + } + + private List buildConversationHistoryAfterToolExecution(List previousMessages, + AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) { + List messages = new ArrayList<>(previousMessages); + messages.add(assistantMessage); + messages.add(toolResponseMessage); + return messages; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private FunctionCallbackResolver functionCallbackResolver; + + private List functionCallbacks = new ArrayList<>(); + + private Builder() { + } + + public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { + this.functionCallbackResolver = functionCallbackResolver; + return this; + } + + public Builder functionCallbacks(List functionCallbacks) { + this.functionCallbacks = functionCallbacks; + return this; + } + + public LegacyToolCallingManager build() { + return new LegacyToolCallingManager(functionCallbackResolver, functionCallbacks); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index 767cab146b..32aed320ef 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -23,6 +23,8 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -185,4 +187,21 @@ else if (chatOptions instanceof FunctionCallingOptions functionCallingOptions return internalToolExecutionEnabled; } + static Set mergeToolNames(Set runtimeToolNames, Set defaultToolNames) { + Assert.notNull(runtimeToolNames, "runtimeToolNames cannot be null"); + Assert.notNull(defaultToolNames, "defaultToolNames cannot be null"); + var mergedToolNames = new HashSet<>(runtimeToolNames); + mergedToolNames.addAll(defaultToolNames); + return mergedToolNames; + } + + static List mergeToolCallbacks(List runtimeToolCallbacks, + List defaultToolCallbacks) { + Assert.notNull(runtimeToolCallbacks, "runtimeToolCallbacks cannot be null"); + Assert.notNull(defaultToolCallbacks, "defaultToolCallbacks cannot be null"); + var mergedToolCallbacks = new ArrayList<>(runtimeToolCallbacks); + mergedToolCallbacks.addAll(defaultToolCallbacks); + return mergedToolCallbacks; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java index bee3ec03cb..699c69acbe 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java @@ -18,6 +18,7 @@ import org.springframework.ai.tool.util.ToolUtils; import org.springframework.ai.util.json.JsonSchemaGenerator; +import org.springframework.util.Assert; import java.lang.reflect.Method; @@ -55,6 +56,7 @@ static DefaultToolDefinition.Builder builder() { * Create a default {@link ToolDefinition} builder from a {@link Method}. */ static DefaultToolDefinition.Builder builder(Method method) { + Assert.notNull(method, "method cannot be null"); return DefaultToolDefinition.builder() .name(ToolUtils.getToolName(method)) .description(ToolUtils.getToolDescription(method)) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverter.java index 30ec1947e8..97ab2895fa 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallExceptionConverter.java @@ -16,6 +16,8 @@ package org.springframework.ai.tool.execution; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.util.Assert; /** @@ -26,6 +28,8 @@ */ public class DefaultToolCallExceptionConverter implements ToolCallExceptionConverter { + private final static Logger logger = LoggerFactory.getLogger(DefaultToolCallExceptionConverter.class); + private static final boolean DEFAULT_ALWAYS_THROW = false; private final boolean alwaysThrow; @@ -40,6 +44,8 @@ public String convert(ToolExecutionException exception) { if (alwaysThrow) { throw exception; } + logger.debug("Exception thrown by tool: {}. Message: {}", exception.getToolDefinition().name(), + exception.getMessage()); return exception.getMessage(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java index 577acbae9c..290846c6f3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java @@ -41,6 +41,7 @@ public String apply(@Nullable Object result, @Nullable Type returnType) { return "Done"; } else { + logger.debug("Converting tool result to JSON."); return JsonParser.toJson(result); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java index 6d5936b4af..8d635a6f47 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java @@ -105,6 +105,11 @@ public String call(String toolInput, @Nullable ToolContext toolContext) { return toolCallResultConverter.apply(response, null); } + @Override + public String toString() { + return "FunctionToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}'; + } + /** * Build a {@link FunctionToolCallback} from a {@link BiFunction}. */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/ToolMetadata.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/ToolMetadata.java index dda2beb7c8..63d785f118 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/ToolMetadata.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/metadata/ToolMetadata.java @@ -17,6 +17,7 @@ package org.springframework.ai.tool.metadata; import org.springframework.ai.tool.util.ToolUtils; +import org.springframework.util.Assert; import java.lang.reflect.Method; @@ -46,6 +47,7 @@ static DefaultToolMetadata.Builder builder() { * Create a default {@link ToolMetadata} instance from a {@link Method}. */ static ToolMetadata from(Method method) { + Assert.notNull(method, "method cannot be null"); return DefaultToolMetadata.builder().returnDirect(ToolUtils.getToolReturnDirect(method)).build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java index c4b153e02a..a2ddf71947 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java @@ -60,6 +60,7 @@ public class MethodToolCallback implements ToolCallback { private final Method toolMethod; + @Nullable private final Object toolObject; private final ToolCallResultConverter toolCallResultConverter; @@ -174,6 +175,11 @@ private boolean isMethodNotPublic() { return !Modifier.isPublic(toolMethod.getModifiers()); } + @Override + public String toString() { + return "MethodToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}'; + } + public static Builder builder() { return new Builder(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java index afbf3e4823..16688b0536 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java @@ -16,7 +16,7 @@ package org.springframework.ai.tool.resolution; -import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -41,9 +41,11 @@ public DelegatingToolCallbackResolver(List toolCallbackRes @Override @Nullable - public ToolCallback resolve(String toolName) { + public FunctionCallback resolve(String toolName) { + Assert.hasText(toolName, "toolName cannot be null or empty"); + for (ToolCallbackResolver toolCallbackResolver : toolCallbackResolvers) { - ToolCallback toolCallback = toolCallbackResolver.resolve(toolName); + FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName); if (toolCallback != null) { return toolCallback; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java index e5c5137902..5465ee8362 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java @@ -20,6 +20,8 @@ import kotlin.jvm.functions.Function0; import kotlin.jvm.functions.Function1; import kotlin.jvm.functions.Function2; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.function.FunctionToolCallback; @@ -55,6 +57,8 @@ */ public class SpringBeanToolCallbackResolver implements ToolCallbackResolver { + private static final Logger logger = LoggerFactory.getLogger(SpringBeanToolCallbackResolver.class); + private static final Map toolCallbacksCache = new HashMap<>(); private static final SchemaType DEFAULT_SCHEMA_TYPE = SchemaType.JSON_SCHEMA; @@ -75,6 +79,8 @@ public SpringBeanToolCallbackResolver(GenericApplicationContext applicationConte public ToolCallback resolve(String toolName) { Assert.hasText(toolName, "toolName cannot be null or empty"); + logger.debug("ToolCallback resolution attempt from Spring application context"); + ToolCallback resolvedToolCallback = toolCallbacksCache.get(toolName); if (resolvedToolCallback != null) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java index 24d0d14b32..4e7352edb7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java @@ -16,6 +16,9 @@ package org.springframework.ai.tool.resolution; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; import org.springframework.util.Assert; @@ -31,18 +34,26 @@ */ public class StaticToolCallbackResolver implements ToolCallbackResolver { - private final Map toolCallbacks = new HashMap<>(); + private static final Logger logger = LoggerFactory.getLogger(StaticToolCallbackResolver.class); - public StaticToolCallbackResolver(List toolCallbacks) { + private final Map toolCallbacks = new HashMap<>(); + + public StaticToolCallbackResolver(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); - toolCallbacks - .forEach(toolCallback -> this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback)); + toolCallbacks.forEach(callback -> { + if (callback instanceof ToolCallback toolCallback) { + this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback); + } + this.toolCallbacks.put(callback.getName(), callback); + }); } @Override - public ToolCallback resolve(String toolName) { + public FunctionCallback resolve(String toolName) { + Assert.hasText(toolName, "toolName cannot be null or empty"); + logger.debug("ToolCallback resolution attempt from static registry"); return toolCallbacks.get(toolName); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java index 8efa01e9cc..1155e4042e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java @@ -16,6 +16,7 @@ package org.springframework.ai.tool.resolution; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; @@ -28,9 +29,9 @@ public interface ToolCallbackResolver { /** - * Resolve the {@link ToolCallback} for the given tool name. + * Resolve the {@link FunctionCallback} for the given tool name. */ @Nullable - ToolCallback resolve(String toolName); + FunctionCallback resolve(String toolName); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java index 786d97e92f..bbbfe50922 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java @@ -42,6 +42,7 @@ private ToolUtils() { } public static String getToolName(Method method) { + Assert.notNull(method, "method cannot be null"); var tool = method.getAnnotation(Tool.class); if (tool == null) { return method.getName(); @@ -49,12 +50,13 @@ public static String getToolName(Method method) { return StringUtils.hasText(tool.name()) ? tool.name() : method.getName(); } - public static String getToolDescriptionFromName(@Nullable String toolName) { + public static String getToolDescriptionFromName(String toolName) { Assert.hasText(toolName, "toolName cannot be null or empty"); return ParsingUtils.reConcatenateCamelCase(toolName, " "); } public static String getToolDescription(Method method) { + Assert.notNull(method, "method cannot be null"); var tool = method.getAnnotation(Tool.class); if (tool == null) { return ParsingUtils.reConcatenateCamelCase(method.getName(), " "); @@ -63,11 +65,13 @@ public static String getToolDescription(Method method) { } public static boolean getToolReturnDirect(Method method) { + Assert.notNull(method, "method cannot be null"); var tool = method.getAnnotation(Tool.class); return tool != null && tool.returnDirect(); } public static ToolCallResultConverter getToolCallResultConverter(Method method) { + Assert.notNull(method, "method cannot be null"); var tool = method.getAnnotation(Tool.class); if (tool == null) { return new DefaultToolCallResultConverter(); @@ -81,8 +85,9 @@ public static ToolCallResultConverter getToolCallResultConverter(Method method) } } - public static List getDuplicateToolNames(FunctionCallback... functionCallbacks) { - return Stream.of(functionCallbacks) + public static List getDuplicateToolNames(FunctionCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + return Stream.of(toolCallbacks) .collect(Collectors.groupingBy(FunctionCallback::getName, Collectors.counting())) .entrySet() .stream() diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/LegacyToolCallingManagerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/LegacyToolCallingManagerTests.java new file mode 100644 index 0000000000..7e6adffaca --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/LegacyToolCallingManagerTests.java @@ -0,0 +1,211 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.model.tool; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolExecutionException; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link LegacyToolCallingManager}. + * + * @author Thomas Vitale + */ +class LegacyToolCallingManagerTests { + + // RESOLVE TOOL DEFINITIONS + + @Test + void whenChatOptionsIsNullThenThrow() { + ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder().build(); + assertThatThrownBy(() -> toolCallingManager.resolveToolDefinitions(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatOptions cannot be null"); + } + + @Test + void whenToolCallbackExistsThenResolve() { + ToolCallback toolCallback = new TestToolCallback("toolA"); + ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder() + .functionCallbacks(List.of(toolCallback)) + .build(); + + List toolDefinitions = toolCallingManager + .resolveToolDefinitions(ToolCallingChatOptions.builder().tools("toolA").build()); + + assertThat(toolDefinitions).containsExactly(toolCallback.getToolDefinition()); + } + + @Test + void whenToolCallbackDoesNotExistThenThrow() { + ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder().functionCallbacks(List.of()).build(); + + assertThatThrownBy(() -> toolCallingManager + .resolveToolDefinitions(ToolCallingChatOptions.builder().tools("toolB").build())) + .isInstanceOf(IllegalStateException.class) + .hasMessage("No ToolCallback found for tool name: toolB"); + } + + // EXECUTE TOOL CALLS + + @Test + void whenPromptIsNullThenThrow() { + ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder().build(); + assertThatThrownBy(() -> toolCallingManager.executeToolCalls(null, mock(ChatResponse.class))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("prompt cannot be null"); + } + + @Test + void whenChatResponseIsNullThenThrow() { + ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder().build(); + assertThatThrownBy(() -> toolCallingManager.executeToolCalls(mock(Prompt.class), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("chatResponse cannot be null"); + } + + @Test + void whenNoToolCallInChatResponseThenThrow() { + ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder().build(); + assertThatThrownBy(() -> toolCallingManager.executeToolCalls(mock(Prompt.class), + ChatResponse.builder().generations(List.of()).build())) + .isInstanceOf(IllegalStateException.class) + .hasMessage("No tool call requested by the chat model"); + } + + @Test + void whenSingleToolCallInChatResponseThenExecute() { + ToolCallback toolCallback = new LegacyToolCallingManagerTests.TestToolCallback("toolA"); + ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder() + .functionCallbacks(List.of(toolCallback)) + .build(); + + Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("", Map.of(), + List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}")))))) + .build(); + + ToolResponseMessage expectedToolResponse = new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"))); + + List toolCallHistory = toolCallingManager.executeToolCalls(prompt, chatResponse); + + assertThat(toolCallHistory).contains(expectedToolResponse); + } + + @Test + void whenMultipleToolCallsInChatResponseThenExecute() { + ToolCallback toolCallbackA = new LegacyToolCallingManagerTests.TestToolCallback("toolA"); + ToolCallback toolCallbackB = new LegacyToolCallingManagerTests.TestToolCallback("toolB"); + ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder() + .functionCallbacks(List.of(toolCallbackA, toolCallbackB)) + .build(); + + Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("", Map.of(), + List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) + .build(); + + ToolResponseMessage expectedToolResponse = new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", "Mission accomplished!"), + new ToolResponseMessage.ToolResponse("toolB", "toolB", "Mission accomplished!"))); + + List toolCallHistory = toolCallingManager.executeToolCalls(prompt, chatResponse); + + assertThat(toolCallHistory).contains(expectedToolResponse); + } + + @Test + void whenToolCallWithExceptionThenReturnError() { + ToolCallback toolCallback = new LegacyToolCallingManagerTests.FailingToolCallback("toolC"); + ToolCallingManager toolCallingManager = LegacyToolCallingManager.builder() + .functionCallbacks(List.of(toolCallback)) + .build(); + + Prompt prompt = new Prompt(new UserMessage("Hello"), ToolCallingChatOptions.builder().build()); + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("", Map.of(), + List.of(new AssistantMessage.ToolCall("toolC", "function", "toolC", "{}")))))) + .build(); + + ToolResponseMessage expectedToolResponse = new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("toolC", "toolC", "You failed this city!"))); + + List toolCallHistory = toolCallingManager.executeToolCalls(prompt, chatResponse); + + assertThat(toolCallHistory).contains(expectedToolResponse); + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + return "Mission accomplished!"; + } + + } + + static class FailingToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public FailingToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + throw new ToolExecutionException(toolDefinition, new IllegalStateException("You failed this city!")); + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java index 7dba59bb6d..c3f92df258 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java @@ -16,9 +16,15 @@ package org.springframework.ai.model.tool; import org.junit.jupiter.api.Test; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import java.util.List; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; /** * Unit tests for {@link ToolCallingChatOptions}. @@ -67,4 +73,92 @@ void whenFunctionCallingOptionsAndExecutionEnabledDefault() { assertThat(ToolCallingChatOptions.isInternalToolExecutionEnabled(options)).isTrue(); } + @Test + void whenMergeRuntimeAndDefaultToolNames() { + Set runtimeToolNames = Set.of("toolA"); + Set defaultToolNames = Set.of("toolB"); + Set mergedToolNames = ToolCallingChatOptions.mergeToolNames(runtimeToolNames, defaultToolNames); + assertThat(mergedToolNames).containsExactlyInAnyOrder("toolA", "toolB"); + } + + @Test + void whenMergeRuntimeAndEmptyDefaultToolNames() { + Set runtimeToolNames = Set.of("toolA"); + Set defaultToolNames = Set.of(); + Set mergedToolNames = ToolCallingChatOptions.mergeToolNames(runtimeToolNames, defaultToolNames); + assertThat(mergedToolNames).containsExactlyInAnyOrder("toolA"); + } + + @Test + void whenMergeEmptyRuntimeAndDefaultToolNames() { + Set runtimeToolNames = Set.of(); + Set defaultToolNames = Set.of("toolB"); + Set mergedToolNames = ToolCallingChatOptions.mergeToolNames(runtimeToolNames, defaultToolNames); + assertThat(mergedToolNames).containsExactlyInAnyOrder("toolB"); + } + + @Test + void whenMergeEmptyRuntimeAndEmptyDefaultToolNames() { + Set runtimeToolNames = Set.of(); + Set defaultToolNames = Set.of(); + Set mergedToolNames = ToolCallingChatOptions.mergeToolNames(runtimeToolNames, defaultToolNames); + assertThat(mergedToolNames).containsExactlyInAnyOrder(); + } + + @Test + void whenMergeRuntimeAndDefaultToolCallbacks() { + List runtimeToolCallbacks = List.of(new TestToolCallback("toolA")); + List defaultToolCallbacks = List.of(new TestToolCallback("toolB")); + List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, + defaultToolCallbacks); + assertThat(mergedToolCallbacks).hasSize(2); + } + + @Test + void whenMergeRuntimeAndEmptyDefaultToolCallbacks() { + List runtimeToolCallbacks = List.of(new TestToolCallback("toolA")); + List defaultToolCallbacks = List.of(); + List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, + defaultToolCallbacks); + assertThat(mergedToolCallbacks).hasSize(1); + } + + @Test + void whenMergeEmptyRuntimeAndDefaultToolCallbacks() { + List runtimeToolCallbacks = List.of(); + List defaultToolCallbacks = List.of(new TestToolCallback("toolB")); + List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, + defaultToolCallbacks); + assertThat(mergedToolCallbacks).hasSize(1); + } + + @Test + void whenMergeEmptyRuntimeAndEmptyDefaultToolCallbacks() { + List runtimeToolCallbacks = List.of(); + List defaultToolCallbacks = List.of(); + List mergedToolCallbacks = ToolCallingChatOptions.mergeToolCallbacks(runtimeToolCallbacks, + defaultToolCallbacks); + assertThat(mergedToolCallbacks).hasSize(0); + } + + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + return "Mission accomplished!"; + } + + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolverTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolverTests.java index c7398a9b13..2c20b74dd9 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolverTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolverTests.java @@ -17,6 +17,7 @@ package org.springframework.ai.tool.resolution; import org.junit.jupiter.api.Test; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.ToolDefinition; @@ -42,7 +43,7 @@ void whenToolCallbacksAreNullThenThrowException() { @Test void whenToolCallbacksContainNullElementsThenThrowException() { - var toolCallbacks = new ArrayList(); + var toolCallbacks = new ArrayList(); toolCallbacks.add(null); assertThatThrownBy(() -> new StaticToolCallbackResolver(toolCallbacks)) .isInstanceOf(IllegalArgumentException.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java new file mode 100644 index 0000000000..18804a9b1a --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfiguration.java @@ -0,0 +1,78 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.model; + +import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.execution.DefaultToolCallExceptionConverter; +import org.springframework.ai.tool.execution.ToolCallExceptionConverter; +import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; +import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; +import org.springframework.ai.tool.resolution.StaticToolCallbackResolver; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.context.annotation.Bean; +import org.springframework.context.support.GenericApplicationContext; + +import java.util.List; + +/** + * Auto-configuration for common tool calling features of {@link ChatModel}. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +@AutoConfiguration +@ConditionalOnClass(ChatModel.class) +public class ToolCallingAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext, + List toolCallbacks) { + var staticToolCallbackResolver = new StaticToolCallbackResolver(toolCallbacks); + var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder() + .applicationContext(applicationContext) + .build(); + + return new DelegatingToolCallbackResolver(List.of(staticToolCallbackResolver, springBeanToolCallbackResolver)); + } + + @Bean + @ConditionalOnMissingBean + ToolCallExceptionConverter toolCallExceptionConverter() { + return new DefaultToolCallExceptionConverter(false); + } + + @Bean + @ConditionalOnMissingBean + ToolCallingManager toolCallingManager(ToolCallbackResolver toolCallbackResolver, + ToolCallExceptionConverter toolCallExceptionConverter, + ObjectProvider observationRegistry) { + return ToolCallingManager.builder() + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .toolCallbackResolver(toolCallbackResolver) + .toolCallExceptionConverter(toolCallExceptionConverter) + .build(); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java index 18b7bc47cc..ecde0c935a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/ollama/OllamaAutoConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,15 +16,14 @@ package org.springframework.ai.autoconfigure.ollama; -import java.util.List; - import io.micrometer.observation.ObservationRegistry; +import org.springframework.ai.autoconfigure.chat.model.ToolCallingAutoConfiguration; import org.springframework.ai.chat.observation.ChatModelObservationConvention; import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.OllamaEmbeddingModel; import org.springframework.ai.ollama.api.OllamaApi; @@ -52,11 +51,12 @@ * @author Thomas Vitale * @since 0.8.0 */ -@AutoConfiguration(after = RestClientAutoConfiguration.class) +@AutoConfiguration(after = { RestClientAutoConfiguration.class, ToolCallingAutoConfiguration.class }) @ConditionalOnClass(OllamaApi.class) @EnableConfigurationProperties({ OllamaChatProperties.class, OllamaEmbeddingProperties.class, OllamaConnectionProperties.class, OllamaInitializationProperties.class }) -@ImportAutoConfiguration(classes = { RestClientAutoConfiguration.class, WebClientAutoConfiguration.class }) +@ImportAutoConfiguration(classes = { RestClientAutoConfiguration.class, ToolCallingAutoConfiguration.class, + WebClientAutoConfiguration.class }) public class OllamaAutoConfiguration { @Bean @@ -80,8 +80,8 @@ public OllamaApi ollamaApi(OllamaConnectionDetails connectionDetails, @ConditionalOnProperty(prefix = OllamaChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties properties, - OllamaInitializationProperties initProperties, List toolFunctionCallbacks, - FunctionCallbackResolver functionCallbackResolver, ObjectProvider observationRegistry, + OllamaInitializationProperties initProperties, ToolCallingManager toolCallingManager, + ObjectProvider observationRegistry, ObjectProvider observationConvention) { var chatModelPullStrategy = initProperties.getChat().isInclude() ? initProperties.getPullModelStrategy() : PullModelStrategy.NEVER; @@ -89,8 +89,7 @@ public OllamaChatModel ollamaChatModel(OllamaApi ollamaApi, OllamaChatProperties var chatModel = OllamaChatModel.builder() .ollamaApi(ollamaApi) .defaultOptions(properties.getOptions()) - .functionCallbackResolver(functionCallbackResolver) - .toolFunctionCallbacks(toolFunctionCallbacks) + .toolCallingManager(toolCallingManager) .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .modelManagementOptions( new ModelManagementOptions(chatModelPullStrategy, initProperties.getChat().getAdditionalModels(), diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java new file mode 100644 index 0000000000..7c1d20214d --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/model/ToolCallingAutoConfigurationTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.model; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.DefaultToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.tool.execution.DefaultToolCallExceptionConverter; +import org.springframework.ai.tool.execution.ToolCallExceptionConverter; +import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for {@link ToolCallingAutoConfiguration}. + * + * @author Thomas Vitale + */ +class ToolCallingAutoConfigurationTests { + + @Test + void beansAreCreated() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .run(context -> { + var toolCallbackResolver = context.getBean(ToolCallbackResolver.class); + assertThat(toolCallbackResolver).isInstanceOf(DelegatingToolCallbackResolver.class); + + var toolCallExceptionConverter = context.getBean(ToolCallExceptionConverter.class); + assertThat(toolCallExceptionConverter).isInstanceOf(DefaultToolCallExceptionConverter.class); + + var toolCallingManager = context.getBean(ToolCallingManager.class); + assertThat(toolCallingManager).isInstanceOf(DefaultToolCallingManager.class); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java index 0213febd09..807975a49f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/OllamaImage.java @@ -18,7 +18,7 @@ public final class OllamaImage { - public static final String DEFAULT_IMAGE = "ollama/ollama:0.5.1"; + public static final String DEFAULT_IMAGE = "ollama/ollama:0.5.7"; private OllamaImage() { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionToolBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionToolBeanIT.java new file mode 100644 index 0000000000..10e8251356 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionToolBeanIT.java @@ -0,0 +1,194 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.ollama.tool; + +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; + +import org.springframework.ai.autoconfigure.ollama.BaseOllamaIT; +import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.ollama.OllamaChatModel; +import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.tool.ToolCallbacks; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Description; +import org.springframework.core.log.LogAccessor; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for function-based tool calling in Ollama. + * + * @author Thomas Vitale + */ +public class OllamaFunctionToolBeanIT extends BaseOllamaIT { + + private static final LogAccessor logger = new LogAccessor(OllamaFunctionToolBeanIT.class); + + private static final String MODEL_NAME = "qwen2.5:3b"; + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.ollama.baseUrl=" + getBaseUrl(), + "spring.ai.ollama.chat.options.model=" + MODEL_NAME, + "spring.ai.ollama.chat.options.temperature=0.5", + "spring.ai.ollama.chat.options.topK=10") + // @formatter:on + .withConfiguration(AutoConfigurations.of(OllamaAutoConfiguration.class)) + .withUserConfiguration(Config.class); + + @BeforeAll + public static void beforeAll() { + initializeOllama(MODEL_NAME); + } + + @Test + void toolCallTest() { + this.contextRunner.run(context -> { + + OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); + + MyTools myTools = context.getBean(MyTools.class); + + UserMessage userMessage = new UserMessage( + "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + OllamaOptions.builder().toolCallbacks(ToolCallbacks.from(myTools)).build())); + + logger.info("Response: " + response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + }); + + } + + @Test + void functionCallTest() { + this.contextRunner.run(context -> { + + OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); + + UserMessage userMessage = new UserMessage( + "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); + + ChatResponse response = chatModel + .call(new Prompt(List.of(userMessage), OllamaOptions.builder().tools("weatherInfo").build())); + + logger.info("Response: " + response); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + }); + } + + @Test + void streamFunctionCallTest() { + this.contextRunner.run(context -> { + + OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); + + UserMessage userMessage = new UserMessage( + "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); + + Flux response = chatModel + .stream(new Prompt(List.of(userMessage), OllamaOptions.builder().function("weatherInfo").build())); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getText) + .collect(Collectors.joining()); + logger.info("Response: " + content); + + assertThat(content).contains("30", "10", "15"); + }); + } + + @Test + void functionCallWithPortableFunctionCallingOptions() { + this.contextRunner.run(context -> { + + OllamaChatModel chatModel = context.getBean(OllamaChatModel.class); + + // Test weatherFunction + UserMessage userMessage = new UserMessage( + "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); + + ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder().tools("weatherInfo").build(); + + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); + + logger.info("Response: " + response.getResult().getOutput().getText()); + + assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); + }); + } + + static class MyTools { + + @Tool(description = "Find the weather conditions, and temperatures for a location, like a city or state.") + public String weatherByLocation(String locationName) { + int temperature = 0; + if (locationName.equals("San Francisco")) { + temperature = 30; + } + else if (locationName.equals("Tokyo")) { + temperature = 10; + } + else if (locationName.equals("Paris")) { + temperature = 15; + } + return "The temperature in " + locationName + " is " + temperature + " degrees Celsius."; + } + + } + + @Configuration + static class Config { + + @Bean + @Description("Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") + public Function weatherInfo() { + return new MockWeatherService(); + } + + @Bean + public MyTools myTools() { + return new MyTools(); + } + + } + +} diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java index 909dc22599..0c2703069a 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/ollama/OllamaImage.java @@ -23,7 +23,7 @@ */ public final class OllamaImage { - public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.5.1"); + public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.5.7"); private OllamaImage() {