Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advancing Tool Support - Part 6 #2169

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.DefaultToolCallExceptionConverter;
import org.springframework.ai.tool.execution.ToolCallExceptionConverter;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
Expand Down Expand Up @@ -62,26 +62,26 @@ public class DefaultToolCallingManager implements ToolCallingManager {
private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER
= new DelegatingToolCallbackResolver(List.of());

private static final ToolCallExceptionConverter DEFAULT_TOOL_CALL_EXCEPTION_CONVERTER
= DefaultToolCallExceptionConverter.builder().build();
private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
= DefaultToolExecutionExceptionProcessor.builder().build();

// @formatter:on

private final ObservationRegistry observationRegistry;

private final ToolCallbackResolver toolCallbackResolver;

private final ToolCallExceptionConverter toolCallExceptionConverter;
private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;

public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
ToolCallExceptionConverter toolCallExceptionConverter) {
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
Assert.notNull(toolCallExceptionConverter, "toolCallExceptionConverter cannot be null");
Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");

this.observationRegistry = observationRegistry;
this.toolCallbackResolver = toolCallbackResolver;
this.toolCallExceptionConverter = toolCallExceptionConverter;
this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
}

@Override
Expand Down Expand Up @@ -214,7 +214,7 @@ else if (toolCallback instanceof ToolCallback callback) {
toolResult = toolCallback.call(toolInputArguments, toolContext);
}
catch (ToolExecutionException ex) {
toolResult = toolCallExceptionConverter.convert(ex);
toolResult = toolExecutionExceptionProcessor.process(ex);
}

toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult));
Expand Down Expand Up @@ -244,7 +244,7 @@ public static class Builder {

private ToolCallbackResolver toolCallbackResolver = DEFAULT_TOOL_CALLBACK_RESOLVER;

private ToolCallExceptionConverter toolCallExceptionConverter = DEFAULT_TOOL_CALL_EXCEPTION_CONVERTER;
private ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR;

private Builder() {
}
Expand All @@ -259,13 +259,15 @@ public Builder toolCallbackResolver(ToolCallbackResolver toolCallbackResolver) {
return this;
}

public Builder toolCallExceptionConverter(ToolCallExceptionConverter toolCallExceptionConverter) {
this.toolCallExceptionConverter = toolCallExceptionConverter;
public Builder toolExecutionExceptionProcessor(
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
return this;
}

public DefaultToolCallingManager build() {
return new DefaultToolCallingManager(observationRegistry, toolCallbackResolver, toolCallExceptionConverter);
return new DefaultToolCallingManager(observationRegistry, toolCallbackResolver,
toolExecutionExceptionProcessor);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.DefaultToolCallExceptionConverter;
import org.springframework.ai.tool.execution.ToolCallExceptionConverter;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -59,7 +59,8 @@ public class LegacyToolCallingManager implements ToolCallingManager {

private final Map<String, FunctionCallback> functionCallbacks = new HashMap<>();

private final ToolCallExceptionConverter toolCallExceptionConverter = DefaultToolCallExceptionConverter.builder()
private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor = DefaultToolExecutionExceptionProcessor
.builder()
.build();

public LegacyToolCallingManager(@Nullable FunctionCallbackResolver functionCallbackResolver,
Expand Down Expand Up @@ -194,7 +195,7 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions)
toolResult = toolCallback.call(toolInputArguments, toolContext);
}
catch (ToolExecutionException ex) {
toolResult = toolCallExceptionConverter.convert(ex);
toolResult = toolExecutionExceptionProcessor.process(ex);
}

toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package org.springframework.ai.tool;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.lang.Nullable;

/**
* Represents a tool whose execution can be triggered by an AI model.
Expand All @@ -40,6 +42,23 @@ default ToolMetadata getToolMetadata() {
return ToolMetadata.builder().build();
}

/**
* Execute tool with the given input and return the result to send back to the AI
* model.
*/
String call(String toolInput);

/**
* Execute tool with the given input and context, and return the result to send back
* to the AI model.
*/
default String call(String toolInput, @Nullable ToolContext tooContext) {
if (tooContext != null && !tooContext.getContext().isEmpty()) {
throw new UnsupportedOperationException("Tool context is not supported!");
}
return call(toolInput);
}

@Override
@Deprecated // Call getToolDefinition().name() instead
default String getName() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public final class DefaultToolCallResultConverter implements ToolCallResultConve
private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class);

@Override
public String apply(@Nullable Object result, @Nullable Type returnType) {
public String convert(@Nullable Object result, @Nullable Type returnType) {
if (returnType == Void.TYPE) {
logger.debug("The tool has no return type. Converting to conventional response.");
return "Done";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,25 @@
import org.springframework.util.Assert;

/**
* Default implementation of {@link ToolCallExceptionConverter}.
* Default implementation of {@link ToolExecutionExceptionProcessor}.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public class DefaultToolCallExceptionConverter implements ToolCallExceptionConverter {
public class DefaultToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor {

private final static Logger logger = LoggerFactory.getLogger(DefaultToolCallExceptionConverter.class);
private final static Logger logger = LoggerFactory.getLogger(DefaultToolExecutionExceptionProcessor.class);

private static final boolean DEFAULT_ALWAYS_THROW = false;

private final boolean alwaysThrow;

public DefaultToolCallExceptionConverter(boolean alwaysThrow) {
public DefaultToolExecutionExceptionProcessor(boolean alwaysThrow) {
this.alwaysThrow = alwaysThrow;
}

@Override
public String convert(ToolExecutionException exception) {
public String process(ToolExecutionException exception) {
Assert.notNull(exception, "exception cannot be null");
if (alwaysThrow) {
throw exception;
Expand All @@ -62,8 +62,8 @@ public Builder alwaysThrow(boolean alwaysThrow) {
return this;
}

public DefaultToolCallExceptionConverter build() {
return new DefaultToolCallExceptionConverter(alwaysThrow);
public DefaultToolExecutionExceptionProcessor build() {
return new DefaultToolExecutionExceptionProcessor(alwaysThrow);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.springframework.lang.Nullable;

import java.lang.reflect.Type;
import java.util.function.BiFunction;

/**
* A functional interface to convert tool call results to a String that can be sent back
Expand All @@ -29,12 +28,12 @@
* @since 1.0.0
*/
@FunctionalInterface
public interface ToolCallResultConverter extends BiFunction<Object, Type, String> {
public interface ToolCallResultConverter {

/**
* Given an Object returned by a tool, convert it to a String compatible with the
* given class type.
*/
String apply(@Nullable Object result, @Nullable Type returnType);
String convert(@Nullable Object result, @Nullable Type returnType);

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,20 @@
package org.springframework.ai.tool.execution;

/**
* A functional interface to convert a tool call exception to a String that can be sent
* back to the AI model.
* A functional interface to process a {@link ToolExecutionException} by either converting
* the error message to a String that can be sent back to the AI model or throwing an
* exception to be handled by the caller.
*
* @author Thomas Vitale
* @since 1.0.0
*/
@FunctionalInterface
public interface ToolCallExceptionConverter {
public interface ToolExecutionExceptionProcessor {

/**
* Convert an exception thrown by a tool to a String that can be sent back to the AI
* model.
* model or throw an exception to be handled by the caller.
*/
String convert(ToolExecutionException exception);
String process(ToolExecutionException exception);

}
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public String call(String toolInput, @Nullable ToolContext toolContext) {

logger.debug("Successful execution of tool: {}", toolDefinition.name());

return toolCallResultConverter.apply(response, null);
return toolCallResultConverter.convert(response, null);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ReflectionUtils;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -112,7 +111,7 @@ public String call(String toolInput, @Nullable ToolContext toolContext) {

Type returnType = toolMethod.getGenericReturnType();

return toolCallResultConverter.apply(result, returnType);
return toolCallResultConverter.convert(result, returnType);
}

private void validateToolContextSupport(@Nullable ToolContext toolContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.util.json.schema;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.github.victools.jsonschema.generator.Module;
Expand Down Expand Up @@ -54,6 +55,8 @@
* <ul>
* <li>{@code @ToolParam(required = ..., description = ...)}</li>
* <li>{@code @JsonProperty(required = ...)}</li>
* <li>{@code @JsonClassDescription(...)}</li>
* <li>{@code @JsonPropertyDescription(...)}</li>
* <li>{@code @Schema(required = ..., description = ...)}</li>
* <li>{@code @Nullable}</li>
* </ul>
Expand Down Expand Up @@ -165,13 +168,19 @@ private static void processSchemaOptions(SchemaOption[] schemaOptions, ObjectNod
}

/**
* Determines whether a property is required based on the presence of a series of
* Determines whether a property is required based on the presence of a series of *
* annotations.
*
* <p>
* - {@code @ToolParam(required = ...)} - {@code @JsonProperty(required = ...)} -
* {@code @Schema(required = ...)}
* <ul>
* <li>{@code @ToolParam(required = ...)}</li>
* <li>{@code @JsonProperty(required = ...)}</li>
* <li>{@code @Schema(required = ...)}</li>
* <li>{@code @Nullable}</li>
* </ul>
* <p>
* If none of these annotations are present, the default behavior is to consider the
*
* If none of these annotations are present, the default behavior is to consider the *
* property as required.
*/
private static boolean isMethodParameterRequired(Method method, int index) {
Expand Down Expand Up @@ -201,6 +210,17 @@ private static boolean isMethodParameterRequired(Method method, int index) {
return PROPERTY_REQUIRED_BY_DEFAULT;
}

/**
* Determines a property description based on the presence of a series of annotations.
*
* <p>
* <ul>
* <li>{@code @ToolParam(description = ...)}</li>
* <li>{@code @JsonPropertyDescription(...)}</li>
* <li>{@code @Schema(description = ...)}</li>
* </ul>
* <p>
*/
@Nullable
private static String getMethodParameterDescription(Method method, int index) {
Parameter parameter = method.getParameters()[index];
Expand All @@ -210,6 +230,11 @@ private static String getMethodParameterDescription(Method method, int index) {
return toolParamAnnotation.description();
}

var jacksonAnnotation = parameter.getAnnotation(JsonPropertyDescription.class);
if (jacksonAnnotation != null && StringUtils.hasText(jacksonAnnotation.value())) {
return jacksonAnnotation.value();
}

var schemaAnnotation = parameter.getAnnotation(Schema.class);
if (schemaAnnotation != null && StringUtils.hasText(schemaAnnotation.description())) {
return schemaAnnotation.description();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolCallExceptionConverter;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
Expand Down Expand Up @@ -59,7 +59,7 @@ void whenObservationRegistryIsNullThenThrow() {
assertThatThrownBy(() -> DefaultToolCallingManager.builder()
.observationRegistry(null)
.toolCallbackResolver(mock(ToolCallbackResolver.class))
.toolCallExceptionConverter(mock(ToolCallExceptionConverter.class))
.toolExecutionExceptionProcessor(mock(ToolExecutionExceptionProcessor.class))
.build()).isInstanceOf(IllegalArgumentException.class).hasMessage("observationRegistry cannot be null");
}

Expand All @@ -68,7 +68,7 @@ void whenToolCallbackResolverIsNullThenThrow() {
assertThatThrownBy(() -> DefaultToolCallingManager.builder()
.observationRegistry(mock(ObservationRegistry.class))
.toolCallbackResolver(null)
.toolCallExceptionConverter(mock(ToolCallExceptionConverter.class))
.toolExecutionExceptionProcessor(mock(ToolExecutionExceptionProcessor.class))
.build()).isInstanceOf(IllegalArgumentException.class).hasMessage("toolCallbackResolver cannot be null");
}

Expand All @@ -77,7 +77,7 @@ void whenToolCallExceptionConverterIsNullThenThrow() {
assertThatThrownBy(() -> DefaultToolCallingManager.builder()
.observationRegistry(mock(ObservationRegistry.class))
.toolCallbackResolver(mock(ToolCallbackResolver.class))
.toolCallExceptionConverter(null)
.toolExecutionExceptionProcessor(null)
.build()).isInstanceOf(IllegalArgumentException.class)
.hasMessage("toolCallExceptionConverter cannot be null");
}
Expand Down
Loading