From 4ac010607795cbb9bc7e84bdbbbcfcbd05803c44 Mon Sep 17 00:00:00 2001 From: mariofusco Date: Tue, 29 Oct 2024 17:25:01 +0100 Subject: [PATCH] Allow to rewrite LLM result in an OutputGuardrail --- .../guardrails/OutputGuardrailChainTest.java | 58 +++++++++++++++++++ ...drailOnStreamedResponseValidationTest.java | 21 +++++++ .../guardrails/GuardrailParams.java | 8 +++ .../guardrails/GuardrailResult.java | 16 +++++ .../guardrails/InputGuardrailParams.java | 5 ++ .../guardrails/OutputGuardrail.java | 8 +++ .../guardrails/OutputGuardrailParams.java | 9 +++ .../guardrails/OutputGuardrailResult.java | 34 +++++++++-- .../AiServiceMethodImplementationSupport.java | 4 ++ .../runtime/aiservice/GuardrailsSupport.java | 34 ++++++++--- 10 files changed, 184 insertions(+), 13 deletions(-) diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java index 43c809a0d..b353f817e 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java @@ -1,6 +1,7 @@ package io.quarkiverse.langchain4j.test.guardrails; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; @@ -28,6 +29,7 @@ import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -78,6 +80,20 @@ void testThatRetryRestartTheChain() { assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess()); } + @Test + @ActivateRequestContext + void testThatRewritesTheOutputTwiceInTheChain() { + assertThat(aiService.rewritingSuccess("1", "foo")).isEqualTo("Hi!,1,2"); + } + + @Test + @ActivateRequestContext + void testThatRepromptAfterRewriteIsNotAllowed() { + assertThatExceptionOfType(GuardrailException.class) + .isThrownBy(() -> aiService.repromptAfterRewrite("1", "foo")) + .withMessageContaining("Retry or reprompt is not allowed after a rewritten output"); + } + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) public interface MyAiService { @@ -90,6 +106,12 @@ public interface MyAiService { @OutputGuardrails({ FirstGuardrail.class, FailingGuardrail.class, SecondGuardrail.class }) String failingFirstTwo(@MemoryId String mem, @UserMessage String message); + @OutputGuardrails({ FirstRewritingGuardrail.class, SecondRewritingGuardrail.class }) + String rewritingSuccess(@MemoryId String mem, @UserMessage String message); + + @OutputGuardrails({ FirstRewritingGuardrail.class, RepromptingGuardrail.class }) + String repromptAfterRewrite(@MemoryId String mem, @UserMessage String message); + } @RequestScoped @@ -164,6 +186,42 @@ public int spy() { } } + @RequestScoped + public static class FirstRewritingGuardrail implements OutputGuardrail { + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + String text = responseFromLLM.text(); + return successWith(text + ",1"); + } + } + + @RequestScoped + public static class SecondRewritingGuardrail implements OutputGuardrail { + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + String text = responseFromLLM.text(); + return successWith(text + ",2"); + } + } + + @RequestScoped + public static class RepromptingGuardrail implements OutputGuardrail { + + private boolean firstCall = true; + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + if (firstCall) { + firstCall = false; + String text = responseFromLLM.text(); + return reprompt("Wrong message", text + ", " + text); + } + return success(); + } + } + public static class MyChatModelSupplier implements Supplier { @Override diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseValidationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseValidationTest.java index baf41f082..c28180a85 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseValidationTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnStreamedResponseValidationTest.java @@ -139,6 +139,14 @@ void testFatalExceptionWithPassThroughAccumulator() { assertThat(fatal.spy()).isEqualTo(1); } + @Test + @ActivateRequestContext + void testRewritingWhileStreamingIsNotAllowed() { + assertThatThrownBy(() -> aiService.rewriting("1").collect().asList().await().indefinitely()) + .isInstanceOf(GuardrailException.class) + .hasMessageContaining("Attempting to rewrite the LLM output while streaming is not allowed"); + } + @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) public interface MyAiService { @@ -187,6 +195,9 @@ public interface MyAiService { @OutputGuardrailAccumulator(PassThroughAccumulator.class) Multi fatalWithPassThroughAccumulator(@MemoryId String mem); + @UserMessage("Say Hi!") + @OutputGuardrails({ RewritingGuardrail.class }) + Multi rewriting(@MemoryId String mem); } @RequestScoped @@ -272,6 +283,16 @@ public int spy() { } } + @RequestScoped + public static class RewritingGuardrail implements OutputGuardrail { + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + String text = responseFromLLM.text(); + return successWith(text + ",1"); + } + } + public static class MyChatModelSupplier implements Supplier { @Override diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java index 6706efbb3..e04fa76b8 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java @@ -18,4 +18,12 @@ public interface GuardrailParams { * @return the augmentation result, can be {@code null} */ AugmentationResult augmentationResult(); + + /** + * Recreate this guardrail param with the given input or output text. + * + * @param text The text of the rewritten param. + * @return A clone of this guardrail params with the given input or output text. + */ + GuardrailParams withText(String text); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java index 1246f04aa..f731a17ea 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java @@ -15,6 +15,10 @@ enum Result { * A successful validation. */ SUCCESS, + /** + * A successful validation with a specific result. + */ + SUCCESS_WITH_RESULT, /** * A failed validation not preventing the subsequent validations eventually registered to be evaluated. */ @@ -27,6 +31,18 @@ enum Result { boolean isSuccess(); + default boolean isRewrittenResult() { + return false; + } + + default GuardrailResult blockRetry() { + throw new UnsupportedOperationException(); + } + + default String successfulResult() { + throw new UnsupportedOperationException(); + } + boolean isFatal(); /** diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java index 1900d27b1..62bdcbfca 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java @@ -18,4 +18,9 @@ public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory, AugmentationResult augmentationResult, String userMessageTemplate, Map variables) implements GuardrailParams { + + @Override + public InputGuardrailParams withText(String text) { + throw new UnsupportedOperationException(); + } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java index 51487eca5..762b5478f 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java @@ -50,6 +50,14 @@ default OutputGuardrailResult success() { return OutputGuardrailResult.success(); } + /** + * @return The result of a successful output guardrail validation with a specific result. + * @param successfulResult The successful result. + */ + default OutputGuardrailResult successWith(String successfulResult) { + return OutputGuardrailResult.successWith(successfulResult); + } + /** * @param message A message describing the failure. * @return The result of a failed output guardrail validation. diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java index 0162c5f5a..f8217561d 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java @@ -1,7 +1,9 @@ package io.quarkiverse.langchain4j.guardrails; +import java.util.List; import java.util.Map; +import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.rag.AugmentationResult; @@ -18,4 +20,11 @@ public record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory, AugmentationResult augmentationResult, String userMessageTemplate, Map variables) implements GuardrailParams { + + @Override + public OutputGuardrailParams withText(String text) { + List tools = responseFromLLM.toolExecutionRequests(); + AiMessage aiMessage = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text); + return new OutputGuardrailParams(aiMessage, memory, augmentationResult, userMessageTemplate, variables); + } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java index 7c5bbf33d..2b139fa0d 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java @@ -10,35 +10,54 @@ * @param result The result of the output guardrail validation. * @param failures The list of failures, empty if the validation succeeded. */ -public record OutputGuardrailResult(Result result, List failures) implements GuardrailResult { +public record OutputGuardrailResult(Result result, String successfulResult, + List failures) implements GuardrailResult { private static final OutputGuardrailResult SUCCESS = new OutputGuardrailResult(); private OutputGuardrailResult() { - this(Result.SUCCESS, Collections.emptyList()); + this(Result.SUCCESS, null, Collections.emptyList()); + } + + private OutputGuardrailResult(String successfulResult) { + this(Result.SUCCESS_WITH_RESULT, successfulResult, Collections.emptyList()); } OutputGuardrailResult(List failures, boolean fatal) { - this(fatal ? Result.FATAL : Result.FAILURE, failures); + this(fatal ? Result.FATAL : Result.FAILURE, null, failures); } public static OutputGuardrailResult success() { return SUCCESS; } + public static OutputGuardrailResult successWith(String successfulResult) { + return new OutputGuardrailResult(successfulResult); + } + public static OutputGuardrailResult failure(List failures) { return new OutputGuardrailResult((List) failures, false); } @Override public boolean isSuccess() { - return result == Result.SUCCESS; + return result == Result.SUCCESS || result == Result.SUCCESS_WITH_RESULT; + } + + @Override + public boolean isRewrittenResult() { + return result == Result.SUCCESS_WITH_RESULT; } public boolean isRetry() { return !isSuccess() && failures.stream().anyMatch(Failure::retry); } + public OutputGuardrailResult blockRetry() { + failures().set(0, failures().get(0).blockRetry()); + return this; + } + public String getReprompt() { if (!isSuccess()) { for (Failure failure : failures) { @@ -97,6 +116,13 @@ public Failure withGuardrailClass(Class guardrailClass) { return new Failure(message(), cause(), guardrailClass, retry, reprompt); } + public Failure blockRetry() { + return retry + ? new Failure("Retry or reprompt is not allowed after a rewritten output", cause(), guardrailClass, false, + reprompt) + : this; + } + @Override public String toString() { return "The guardrail " + guardrailClass.getName() + " failed with this message: " + message; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index f484949aa..a06cb9880 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -290,6 +290,10 @@ private List messagesToSend(ChatMessage augmentedUserMessage, throw new GuardrailsSupport.GuardrailRetryException(); } } else { + if (result.isRewrittenResult()) { + throw new GuardrailException( + "Attempting to rewrite the LLM output while streaming is not allowed"); + } return chunk; } }) diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java index 477a2df1b..c4bb7315b 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java @@ -7,6 +7,7 @@ import jakarta.enterprise.inject.spi.CDI; +import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; @@ -57,8 +58,9 @@ public static Response invokeOutputGuardrails(AiServiceMethodCreateIn if (max <= 0) { max = 1; } + + OutputGuardrailResult result = null; while (attempt < max) { - OutputGuardrailResult result; try { result = invokeOutputGuardRails(methodCreateInfo, output); } catch (Exception e) { @@ -97,9 +99,20 @@ public static Response invokeOutputGuardrails(AiServiceMethodCreateIn if (attempt == max) { throw new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries"); } + + if (result.isRewrittenResult()) { + response = rewriteResponseWithText(response, result.successfulResult()); + } + return response; } + public static Response rewriteResponseWithText(Response response, String text) { + List tools = response.content().toolExecutionRequests(); + AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text); + return new Response<>(content, response.tokenUsage(), response.finishReason(), response.metadata()); + } + @SuppressWarnings("unchecked") private static OutputGuardrailResult invokeOutputGuardRails(AiServiceMethodCreateInfo methodCreateInfo, OutputGuardrailParams params) { @@ -160,7 +173,10 @@ private static GR guardrailResult(GuardrailParams p for (Class bean : classes) { GR result = (GR) CDI.current().select(bean).get().validate(params).validatedBy(bean); if (result.isFatal()) { - return result; + return accumulatedResults.isRewrittenResult() ? (GR) result.blockRetry() : result; + } + if (result.isRewrittenResult()) { + params = params.withText(result.successfulResult()); } accumulatedResults = compose(accumulatedResults, result, producer); } @@ -168,17 +184,17 @@ private static GR guardrailResult(GuardrailParams p return accumulatedResults; } - private static GR compose(GR first, GR second, + private static GR compose(GR oldResult, GR newResult, Function, GR> producer) { - if (first.isSuccess()) { - return second; + if (oldResult.isSuccess()) { + return newResult; } - if (second.isSuccess()) { - return first; + if (newResult.isSuccess()) { + return oldResult; } List failures = new ArrayList<>(); - failures.addAll(first.failures()); - failures.addAll(second.failures()); + failures.addAll(oldResult.failures()); + failures.addAll(newResult.failures()); return producer.apply(failures); }