diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailWithAugmentationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailWithAugmentationTest.java index 0a4541ac0..205d1867f 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailWithAugmentationTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailWithAugmentationTest.java @@ -33,8 +33,10 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkus.test.QuarkusUnitTest; import io.smallrye.mutiny.Multi; @@ -121,9 +123,10 @@ public static class MyInputGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(); @Override - public void validate(InputGuardrailParams params) throws ValidationException { + public InputGuardrailResult validate(InputGuardrailParams params) { spy.incrementAndGet(); assertThat(params.augmentationResult().contents()).hasSize(2); + return InputGuardrailResult.success(); } public int getSpy() { @@ -137,9 +140,10 @@ public static class MyOutputGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(); @Override - public void validate(OutputGuardrailParams params) throws ValidationException { + public OutputGuardrailResult validate(OutputGuardrailParams params) { spy.incrementAndGet(); assertThat(params.augmentationResult().contents()).hasSize(2); + return OutputGuardrailResult.success(); } public int getSpy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndOutputGuardrailsTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndOutputGuardrailsTest.java index 366144fa2..6c4d32b59 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndOutputGuardrailsTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputAndOutputGuardrailsTest.java @@ -27,8 +27,10 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkus.test.QuarkusUnitTest; @@ -38,7 +40,8 @@ public class InputAndOutputGuardrailsTest { static final QuarkusUnitTest unitTest = new QuarkusUnitTest() .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) .addClasses(MyAiService.class, - MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class, + ValidationException.class)); @Inject MyOkInputGuardrail okIn; @@ -144,8 +147,9 @@ public static class MyOkInputGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(); @Override - public void validate(InputGuardrailParams params) throws ValidationException { + public InputGuardrailResult validate(InputGuardrailParams params) { spy.incrementAndGet(); + return success(); } public int getSpy() { @@ -159,9 +163,9 @@ public static class MyKoInputGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(); @Override - public void validate(InputGuardrailParams params) throws ValidationException { + public InputGuardrailResult validate(InputGuardrailParams params) { spy.incrementAndGet(); - throw new ValidationException("boom"); + return failure("boom", new ValidationException("boom")); } public int getSpy() { @@ -175,8 +179,9 @@ public static class MyOkOutputGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(); @Override - public void validate(OutputGuardrailParams params) throws ValidationException { + public OutputGuardrailResult validate(OutputGuardrailParams params) { spy.incrementAndGet(); + return OutputGuardrailResult.success(); } public int getSpy() { @@ -190,9 +195,9 @@ public static class MyKoOutputGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(); @Override - public void validate(OutputGuardrailParams params) throws ValidationException { + public OutputGuardrailResult validate(OutputGuardrailParams params) { spy.incrementAndGet(); - throw new ValidationException("boom", false, null); + return failure("boom", new ValidationException("boom")); } public int getSpy() { @@ -206,10 +211,11 @@ public static class MyKoWithRetryOutputGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(); @Override - public void validate(OutputGuardrailParams params) throws ValidationException { + public OutputGuardrailResult validate(OutputGuardrailParams params) { if (spy.incrementAndGet() == 1) { - throw new ValidationException("KO", true, null); + return retry("KO"); } + return success(); } public int getSpy() { @@ -223,10 +229,11 @@ public static class MyKoWithRepromprOutputGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(); @Override - public void validate(OutputGuardrailParams params) throws ValidationException { + public OutputGuardrailResult validate(OutputGuardrailParams params) { if (spy.incrementAndGet() == 1) { - throw new ValidationException("KO", true, "retry"); + return reprompt("KO", "retry"); } + return success(); } public int getSpy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailChainTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailChainTest.java index 467c886e2..63687b7fc 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailChainTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailChainTest.java @@ -26,9 +26,10 @@ import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.GuardrailException; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.InputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -38,7 +39,8 @@ public class InputGuardrailChainTest { static final QuarkusUnitTest unitTest = new QuarkusUnitTest() .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) .addClasses(MyAiService.class, - MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class, + ValidationException.class)); @Inject MyAiService aiService; @@ -74,7 +76,7 @@ void testThatGuardrailOrderIsCorrect() { void testFailureTheChain() { assertThatThrownBy(() -> aiService.failingFirstTwo("1", "foo")) .isInstanceOf(GuardrailException.class) - .hasCauseInstanceOf(InputGuardrail.ValidationException.class) + .hasCauseInstanceOf(ValidationException.class) .hasRootCauseMessage("boom"); assertThat(firstGuardrail.spy()).isEqualTo(1); assertThat(secondGuardrail.spy()).isEqualTo(0); @@ -102,7 +104,7 @@ public static class FirstGuardrail implements InputGuardrail { AtomicLong lastAccess = new AtomicLong(); @Override - public void validate(dev.langchain4j.data.message.UserMessage um) { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { spy.incrementAndGet(); lastAccess.set(System.nanoTime()); try { @@ -110,6 +112,7 @@ public void validate(dev.langchain4j.data.message.UserMessage um) { } catch (InterruptedException e) { // Ignore me } + return success(); } public int spy() { @@ -128,7 +131,7 @@ public static class SecondGuardrail implements InputGuardrail { volatile AtomicLong lastAccess = new AtomicLong(); @Override - public void validate(dev.langchain4j.data.message.UserMessage um) { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { spy.incrementAndGet(); lastAccess.set(System.nanoTime()); try { @@ -136,6 +139,7 @@ public void validate(dev.langchain4j.data.message.UserMessage um) { } catch (InterruptedException e) { // Ignore me } + return success(); } public int spy() { @@ -153,10 +157,11 @@ public static class FailingGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(dev.langchain4j.data.message.UserMessage um) throws ValidationException { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { if (spy.incrementAndGet() == 1) { - throw new ValidationException("boom"); + return fatal("boom", new ValidationException("boom")); } + return success(); } public int spy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailNotFoundTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailNotFoundTest.java index c404361a5..f8d908649 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailNotFoundTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailNotFoundTest.java @@ -24,6 +24,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -59,7 +60,7 @@ public interface MyAiService { public static class MissingGuardRail implements InputGuardrail { @Override - public void validate(dev.langchain4j.data.message.UserMessage um) { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { throw new RuntimeException("Should not be invoked"); } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassAndMethodTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassAndMethodTest.java index 5f7bc3e29..0f53dcb22 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassAndMethodTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassAndMethodTest.java @@ -25,6 +25,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -74,8 +75,9 @@ public static class OKGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(dev.langchain4j.data.message.UserMessage um) { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { spy.incrementAndGet(); + return success(); } public int spy() { @@ -89,9 +91,9 @@ public static class KOGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(dev.langchain4j.data.message.UserMessage um) throws ValidationException { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { spy.incrementAndGet(); - throw new ValidationException("KO"); + return failure("KO"); } public int spy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassTest.java index ff25f8124..deaec7dbd 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailOnClassTest.java @@ -25,6 +25,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -68,8 +69,9 @@ public static class OKGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(dev.langchain4j.data.message.UserMessage ignored) { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage ignored) { spy.incrementAndGet(); + return success(); } public int spy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailTest.java index 136f8bf2f..fed79df4a 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailTest.java @@ -27,6 +27,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.arc.Arc; @@ -38,7 +39,8 @@ public class InputGuardrailTest { static final QuarkusUnitTest unitTest = new QuarkusUnitTest() .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class, - MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class, + ValidationException.class)); @Inject MyAiService aiService; @@ -73,10 +75,10 @@ void testThatInputGuardrailsAreInvoked() { void testThatGuardrailCanThrowValidationException() { assertThat(koGuardrail.spy()).isEqualTo(0); assertThatThrownBy(() -> aiService.ko("1")) - .hasCauseExactlyInstanceOf(InputGuardrail.ValidationException.class); + .hasCauseExactlyInstanceOf(ValidationException.class); assertThat(koGuardrail.spy()).isEqualTo(1); assertThatThrownBy(() -> aiService.ko("1")) - .hasCauseExactlyInstanceOf(InputGuardrail.ValidationException.class); + .hasCauseExactlyInstanceOf(ValidationException.class); assertThat(koGuardrail.spy()).isEqualTo(2); } @@ -99,8 +101,9 @@ public static class OKGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(dev.langchain4j.data.message.UserMessage um) { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { spy.incrementAndGet(); + return success(); } public int spy() { @@ -114,9 +117,9 @@ public static class KOGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(dev.langchain4j.data.message.UserMessage um) throws ValidationException { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { spy.incrementAndGet(); - throw new ValidationException("KO"); + return failure("KO", new ValidationException("KO")); } public int spy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailValidationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailValidationTest.java index 5379f0a40..9d1c841e7 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailValidationTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailValidationTest.java @@ -29,9 +29,10 @@ import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.GuardrailException; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.InputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkus.test.QuarkusUnitTest; import io.smallrye.mutiny.Multi; @@ -42,7 +43,7 @@ public class InputGuardrailValidationTest { .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class, MyChatModel.class, MyStreamingChatModel.class, MyChatModelSupplier.class, - MyMemoryProviderSupplier.class)); + MyMemoryProviderSupplier.class, ValidationException.class)); @Inject MyAiService aiService; @@ -130,8 +131,9 @@ public static class OKGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(dev.langchain4j.data.message.UserMessage um) { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { spy.incrementAndGet(); + return success(); } public int spy() { @@ -145,9 +147,9 @@ public static class KOGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(dev.langchain4j.data.message.UserMessage um) throws ValidationException { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { spy.incrementAndGet(); - throw new ValidationException("KO"); + return failure("KO"); } public int spy() { @@ -161,7 +163,7 @@ public static class KOFatalGuardrail implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(dev.langchain4j.data.message.UserMessage um) throws ValidationException { + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { spy.incrementAndGet(); throw new IllegalArgumentException("Fatal"); } @@ -177,7 +179,7 @@ public static class MemoryCheck implements InputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(InputGuardrail.InputGuardrailParams params) { + public InputGuardrailResult validate(InputGuardrail.InputGuardrailParams params) { spy.incrementAndGet(); if (params.memory().messages().isEmpty()) { assertThat(params.userMessage().singleText()).isEqualTo("foo"); @@ -187,6 +189,7 @@ public void validate(InputGuardrail.InputGuardrailParams params) { assertThat(params.memory().messages().get(1).text()).isEqualTo("Hi!"); assertThat(params.userMessage().singleText()).isEqualTo("bar"); } + return success(); } public int spy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java index bb26ed305..43c809a0d 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java @@ -26,6 +26,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -98,7 +99,7 @@ public static class FirstGuardrail implements OutputGuardrail { AtomicLong lastAccess = new AtomicLong(); @Override - public void validate(AiMessage responseFromLLM) { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { spy.incrementAndGet(); lastAccess.set(System.nanoTime()); try { @@ -106,6 +107,7 @@ public void validate(AiMessage responseFromLLM) { } catch (InterruptedException e) { // Ignore me } + return success(); } public int spy() { @@ -124,7 +126,7 @@ public static class SecondGuardrail implements OutputGuardrail { volatile AtomicLong lastAccess = new AtomicLong(); @Override - public void validate(AiMessage responseFromLLM) { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { spy.incrementAndGet(); lastAccess.set(System.nanoTime()); try { @@ -132,6 +134,7 @@ public void validate(AiMessage responseFromLLM) { } catch (InterruptedException e) { // Ignore me } + return success(); } public int spy() { @@ -149,10 +152,11 @@ public static class FailingGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) throws ValidationException { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { if (spy.incrementAndGet() == 1) { - throw new ValidationException("Retry", true, "Retry"); + return reprompt("Retry", "Retry"); } + return success(); } public int spy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailNotFoundTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailNotFoundTest.java index 3e573a3ed..636d3a8aa 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailNotFoundTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailNotFoundTest.java @@ -24,6 +24,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -59,7 +60,7 @@ public interface MyAiService { public static class MissingGuardRail implements OutputGuardrail { @Override - public void validate(AiMessage responseFromLLM) { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { throw new RuntimeException("Should not be invoked"); } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassAndMethodTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassAndMethodTest.java index 40c75db76..8434744a2 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassAndMethodTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassAndMethodTest.java @@ -25,6 +25,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -74,8 +75,9 @@ public static class OKGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { spy.incrementAndGet(); + return success(); } public int spy() { @@ -89,9 +91,9 @@ public static class KOGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) throws ValidationException { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { spy.incrementAndGet(); - throw new ValidationException("KO", false, null); + return failure("KO"); } public int spy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassTest.java index e1177632f..1e2e4afed 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailOnClassTest.java @@ -25,6 +25,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -68,8 +69,9 @@ public static class OKGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { spy.incrementAndGet(); + return success(); } public int spy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingRetryDisabledTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingRetryDisabledTest.java index e0af8ae2b..fcfe685a2 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingRetryDisabledTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingRetryDisabledTest.java @@ -28,9 +28,10 @@ import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.SystemMessage; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.GuardrailException; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkus.test.QuarkusUnitTest; public class OutputGuardrailRepromptingRetryDisabledTest { @@ -101,8 +102,8 @@ public ChatLanguageModel get() { public static class OkGuardrail implements OutputGuardrail { @Override - public void validate(AiMessage responseFromLLM) { - // OK + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + return success(); } } @@ -113,9 +114,9 @@ public static class RetryGuardrail implements OutputGuardrail { private final AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(OutputGuardrailParams params) throws ValidationException { + public OutputGuardrailResult validate(OutputGuardrailParams params) { int v = spy.incrementAndGet(); - throw new ValidationException("Retry", true, null); + return retry("Retry"); } public int getSpy() { @@ -129,9 +130,9 @@ public static class RepromptingGuardrail implements OutputGuardrail { private final AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(OutputGuardrailParams params) throws ValidationException { + public OutputGuardrailResult validate(OutputGuardrailParams params) { int v = spy.incrementAndGet(); - throw new ValidationException("Retry", true, "reprompt"); + return reprompt("Retry", "reprompt"); } public int getSpy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingTest.java index ab8c38521..54a8b782e 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingTest.java @@ -29,9 +29,10 @@ import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.SystemMessage; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.GuardrailException; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkus.test.QuarkusUnitTest; public class OutputGuardrailRepromptingTest { @@ -103,11 +104,11 @@ public static class RepromptingOne implements OutputGuardrail { private final AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) throws ValidationException { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { if (spy.incrementAndGet() == 1) { - throw new ValidationException("Retry", true, "Retry"); + return reprompt("Retry", "Retry"); } - // OK + return success(); } public int getSpy() { @@ -121,14 +122,14 @@ public static class RepromptingTwo implements OutputGuardrail { private final AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(OutputGuardrailParams params) throws ValidationException { + public OutputGuardrailResult validate(OutputGuardrailParams params) { int v = spy.incrementAndGet(); List messages = params.memory().messages(); if (v == 1) { ChatMessage last = messages.get(messages.size() - 1); assertThat(last).isInstanceOf(AiMessage.class); assertThat(((AiMessage) last).text()).isEqualTo("Nope"); - throw new ValidationException("Retry", true, "Retry"); + return reprompt("Retry", "Retry"); } if (v == 2) { // Check that it's in memory @@ -140,12 +141,12 @@ public void validate(OutputGuardrailParams params) throws ValidationException { assertThat(beforeLast).isInstanceOf(UserMessage.class); assertThat(beforeLast.text()).isEqualTo("Retry"); - throw new ValidationException("Retry", true, "Retry"); + return reprompt("Retry", "Retry"); } if (v != 3) { throw new IllegalArgumentException("Unexpected call"); } - // OK + return success(); } public int getSpy() { @@ -159,14 +160,14 @@ public static class RepromptingFailed implements OutputGuardrail { private final AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(OutputGuardrailParams params) throws ValidationException { + public OutputGuardrailResult validate(OutputGuardrailParams params) { int v = spy.incrementAndGet(); List messages = params.memory().messages(); if (v == 1) { ChatMessage last = messages.get(messages.size() - 1); assertThat(last).isInstanceOf(AiMessage.class); assertThat(((AiMessage) last).text()).isEqualTo("Nope"); - throw new ValidationException("Retry", true, "Retry Once"); + return reprompt("Retry", "Retry Once"); } if (v == 2) { // Check that it's in memory @@ -177,9 +178,9 @@ public void validate(OutputGuardrailParams params) throws ValidationException { assertThat(((AiMessage) last).text()).isEqualTo("Hello"); assertThat(beforeLast).isInstanceOf(UserMessage.class); assertThat(beforeLast.text()).isEqualTo("Retry Once"); - throw new ValidationException("Retry", true, "Retry Twice"); + return reprompt("Retry", "Retry Twice"); } - throw new ValidationException("Retry", true, "Retry Again"); + return reprompt("Retry", "Retry Again"); } public int getSpy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailTest.java index 8a6b0435b..0331936ff 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailTest.java @@ -27,6 +27,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.arc.Arc; @@ -37,8 +38,8 @@ public class OutputGuardrailTest { @RegisterExtension static final QuarkusUnitTest unitTest = new QuarkusUnitTest() .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) - .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class, - MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + .addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class, MyChatModel.class, + MyChatModelSupplier.class, MyMemoryProviderSupplier.class, ValidationException.class)); @Inject MyAiService aiService; @@ -73,10 +74,10 @@ void testThatOutputGuardrailsAreInvoked() { void testThatGuardrailCanThrowValidationException() { assertThat(koGuardrail.spy()).isEqualTo(0); assertThatThrownBy(() -> aiService.ko("1")) - .hasCauseExactlyInstanceOf(OutputGuardrail.ValidationException.class); + .hasCauseExactlyInstanceOf(ValidationException.class); assertThat(koGuardrail.spy()).isEqualTo(1); assertThatThrownBy(() -> aiService.ko("1")) - .hasCauseExactlyInstanceOf(OutputGuardrail.ValidationException.class); + .hasCauseExactlyInstanceOf(ValidationException.class); assertThat(koGuardrail.spy()).isEqualTo(2); } @@ -99,8 +100,9 @@ public static class OKGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { spy.incrementAndGet(); + return success(); } public int spy() { @@ -114,9 +116,9 @@ public static class KOGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) throws ValidationException { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { spy.incrementAndGet(); - throw new ValidationException("KO", false, null); + return failure("KO", new ValidationException("KO")); } public int spy() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailValidationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailValidationTest.java index 2a35c7faa..95127a5e1 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailValidationTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailValidationTest.java @@ -26,9 +26,10 @@ import dev.langchain4j.service.MemoryId; import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkiverse.langchain4j.guardrails.GuardrailException; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; import io.quarkus.test.QuarkusUnitTest; @@ -91,14 +92,6 @@ void testFatalException() { assertThat(fatal.spy()).isEqualTo(1); } - @Test - @ActivateRequestContext - void testRepromptingWithoutRetry() { - assertThatThrownBy(() -> aiService.repromptingWithoutRetry("6")) - .isInstanceOf(GuardrailException.class) - .hasCauseExactlyInstanceOf(IllegalArgumentException.class); - } - @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) public interface MyAiService { @@ -121,10 +114,6 @@ public interface MyAiService { @UserMessage("Say Hi!") @OutputGuardrails(KOFatalGuardrail.class) String fatal(@MemoryId String mem); - - @UserMessage("Say Hi!") - @OutputGuardrails(RepromptingWithoutRetryGuardrail.class) - String repromptingWithoutRetry(@MemoryId String mem); } @RequestScoped @@ -133,8 +122,9 @@ public static class OKGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { spy.incrementAndGet(); + return success(); } public int spy() { @@ -148,9 +138,9 @@ public static class KOGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) throws ValidationException { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { spy.incrementAndGet(); - throw new ValidationException("KO", false, null); + return failure("KO"); } public int spy() { @@ -164,12 +154,12 @@ public static class RetryingGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) throws ValidationException { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { int v = spy.incrementAndGet(); if (v == 2) { - return; + return OutputGuardrailResult.success(); } - throw new ValidationException("KO", true, null); + return retry("KO"); } public int spy() { @@ -183,9 +173,9 @@ public static class RetryingButFailGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) throws ValidationException { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { int v = spy.incrementAndGet(); - throw new ValidationException("KO", true, null); + return retry("KO"); } public int spy() { @@ -193,23 +183,13 @@ public int spy() { } } - @ApplicationScoped - public static class RepromptingWithoutRetryGuardrail implements OutputGuardrail { - - @Override - public void validate(AiMessage responseFromLLM) throws ValidationException { - throw new ValidationException("KO", false, "Reprompt"); - } - - } - @ApplicationScoped public static class KOFatalGuardrail implements OutputGuardrail { AtomicInteger spy = new AtomicInteger(0); @Override - public void validate(AiMessage responseFromLLM) throws ValidationException { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { spy.incrementAndGet(); throw new IllegalArgumentException("Fatal"); } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/ValidationException.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/ValidationException.java new file mode 100644 index 000000000..88b9aa5d7 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/ValidationException.java @@ -0,0 +1,7 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +public class ValidationException extends RuntimeException { + public ValidationException(String message) { + super(message); + } +} \ No newline at end of file diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/Guardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/Guardrail.java new file mode 100644 index 000000000..06bf0d48b --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/Guardrail.java @@ -0,0 +1,16 @@ +package io.quarkiverse.langchain4j.guardrails; + +/** + * A guardrail is a rule that is applied when interacting with an LLM either to the input (the user message) or to the output of + * the model to ensure that they are safe and meet the expectations of the model. + */ +public interface Guardrail

> { + + /** + * Validate the interaction between the model and the user in one of the two directions. + * + * @param params The parameters of the request or the response to be validated. + * @return The result of this validation. + */ + R validate(P params); +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java new file mode 100644 index 000000000..6706efbb3 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailParams.java @@ -0,0 +1,21 @@ +package io.quarkiverse.langchain4j.guardrails; + +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.rag.AugmentationResult; + +/** + * Represents the parameter passed to {@link Guardrail#validate(GuardrailParams)}} in order to validate an interaction between a + * user and the LLM. + */ +public interface GuardrailParams { + + /** + * @return the memory, can be {@code null} or empty + */ + ChatMemory memory(); + + /** + * @return the augmentation result, can be {@code null} + */ + AugmentationResult augmentationResult(); +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java new file mode 100644 index 000000000..1246f04aa --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java @@ -0,0 +1,60 @@ +package io.quarkiverse.langchain4j.guardrails; + +import java.util.List; + +/** + * The result of the validation of an interaction between a user and the LLM. + */ +public interface GuardrailResult { + + /** + * The possible results of a guardrails validation. + */ + enum Result { + /** + * A successful validation. + */ + SUCCESS, + /** + * A failed validation not preventing the subsequent validations eventually registered to be evaluated. + */ + FAILURE, + /** + * A fatal failed validation, blocking the evaluation of any other validations eventually registered. + */ + FATAL + } + + boolean isSuccess(); + + boolean isFatal(); + + /** + * @return The list of failures eventually resulting from a set of validations. + */ + List failures(); + + default Throwable getFirstFailureException() { + if (!isSuccess()) { + for (Failure failure : failures()) { + if (failure.cause() != null) { + return failure.cause(); + } + } + } + return null; + } + + GR validatedBy(Class guardrailClass); + + /** + * The message and the cause of the failure of a single validation. + */ + interface Failure { + Failure withGuardrailClass(Class guardrailClass); + + String message(); + + Throwable cause(); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java index bf608313c..4090486d0 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.guardrails; +import java.util.Arrays; + import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.rag.AugmentationResult; @@ -10,46 +12,18 @@ * safe and meets the expectations of the model. *

* Implementation should be exposed as a CDI bean, and the class name configured in {@link InputGuardrails#value()} annotation. - *

- * Implementation should throw a {@link ValidationException} when the validation fails. */ @Experimental("This feature is experimental and the API is subject to change") -public interface InputGuardrail { - - /** - * An exception thrown when the validation fails. - */ - class ValidationException extends Exception { - - /** - * Creates a new instance of {@link ValidationException} without a cause. - * - * @param message the error message - */ - public ValidationException(String message) { - super(message); - } - - /** - * Creates a new instance of {@link ValidationException} with a cause. - * - * @param message the error message - * @param cause the cause - */ - public ValidationException(String message, Throwable cause) { - super(message, cause); - } - } +public interface InputGuardrail extends Guardrail { /** * Validates the {@code user message} that will be sent to the LLM. *

* * @param userMessage the response from the LLM - * @throws ValidationException the exception throws if the validation fails. */ - default void validate(UserMessage userMessage) throws ValidationException { - throw new ValidationException("Validation not implemented"); + default InputGuardrailResult validate(UserMessage userMessage) { + return failure("Validation not implemented"); } /** @@ -62,11 +36,10 @@ default void validate(UserMessage userMessage) throws ValidationException { * * @param params the parameters, including the user message, the memory (maybe null), * and the augmentation result (maybe null). Cannot be {@code null} - * @throws ValidationException the exception throws if the validation fails. */ - default void validate(InputGuardrailParams params) - throws ValidationException { - validate(params.userMessage()); + @Override + default InputGuardrailResult validate(InputGuardrailParams params) { + return validate(params.userMessage()); } /** @@ -76,7 +49,50 @@ default void validate(InputGuardrailParams params) * @param memory the memory, can be {@code null} or empty * @param augmentationResult the augmentation result, can be {@code null} */ - record InputGuardrailParams(UserMessage userMessage, ChatMemory memory, AugmentationResult augmentationResult) { + record InputGuardrailParams(UserMessage userMessage, ChatMemory memory, + AugmentationResult augmentationResult) implements GuardrailParams { } + /** + * @return The result of a successful input guardrail validation. + */ + default InputGuardrailResult success() { + return InputGuardrailResult.success(); + } + + /** + * @param message A message describing the failure. + * @return The result of a failed input guardrail validation. + */ + default InputGuardrailResult failure(String message) { + return new InputGuardrailResult(Arrays.asList(new InputGuardrailResult.Failure(message)), false); + } + + /** + * @param message A message describing the failure. + * @param cause The exception that caused this failure. + * @return The result of a failed input guardrail validation. + */ + default InputGuardrailResult failure(String message, Throwable cause) { + return new InputGuardrailResult(Arrays.asList(new InputGuardrailResult.Failure(message, cause)), false); + } + + /** + * @param message A message describing the failure. + * @return The result of a fatally failed input guardrail validation, blocking the evaluation of any other subsequent + * validation. + */ + default InputGuardrailResult fatal(String message) { + return new InputGuardrailResult(Arrays.asList(new InputGuardrailResult.Failure(message)), true); + } + + /** + * @param message A message describing the failure. + * @param cause The exception that caused this failure. + * @return The result of a fatally failed input guardrail validation, blocking the evaluation of any other subsequent + * validation. + */ + default InputGuardrailResult fatal(String message, Throwable cause) { + return new InputGuardrailResult(Arrays.asList(new InputGuardrailResult.Failure(message, cause)), true); + } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java new file mode 100644 index 000000000..8c56b7953 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java @@ -0,0 +1,81 @@ +package io.quarkiverse.langchain4j.guardrails; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * The result of the validation of an {@link InputGuardrail} + * + * @param result The result of the input guardrail validation. + * @param failures The list of failures, empty if the validation succeeded. + */ +public record InputGuardrailResult(Result result, List failures) implements GuardrailResult { + + private static final InputGuardrailResult SUCCESS = new InputGuardrailResult(); + + private InputGuardrailResult() { + this(Result.SUCCESS, Collections.emptyList()); + } + + InputGuardrailResult(List failures, boolean fatal) { + this(fatal ? Result.FATAL : Result.FAILURE, failures); + } + + public static InputGuardrailResult success() { + return InputGuardrailResult.SUCCESS; + } + + public static InputGuardrailResult failure(List failures) { + return new InputGuardrailResult((List) failures, false); + } + + @Override + public boolean isSuccess() { + return result == Result.SUCCESS; + } + + @Override + public boolean isFatal() { + return result == Result.FATAL; + } + + @Override + public InputGuardrailResult validatedBy(Class guardrailClass) { + if (!isSuccess()) { + if (failures.size() != 1) { + throw new IllegalArgumentException(); + } + failures.set(0, failures.get(0).withGuardrailClass(guardrailClass)); + } + return this; + } + + @Override + public String toString() { + if (isSuccess()) { + return "success"; + } + return failures.stream().map(Failure::toString).collect(Collectors.joining(", ")); + } + + record Failure(String message, Throwable cause, + Class guardrailClass) implements GuardrailResult.Failure { + public Failure(String message) { + this(message, null); + } + + public Failure(String message, Throwable cause) { + this(message, cause, null); + } + + public Failure withGuardrailClass(Class guardrailClass) { + return new Failure(message, cause, guardrailClass); + } + + @Override + public String toString() { + return "The guardrail " + guardrailClass.getName() + " failed with this message: " + message; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrails.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrails.java index 3e81134e9..2269675eb 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrails.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrails.java @@ -6,6 +6,8 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; +import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException; + /** * An annotation to apply guardrails to the input of the model. *

@@ -15,8 +17,6 @@ * the input is safe and meets the expectations of the model. * It does not replace moderation model, but it can be used to add additional checks. *

- * When a validation fails, the guardrail throws a {@link InputGuardrail.ValidationException}. - *

* Unlike for output guardrails, the input guardrails do not support retry or reprompt. * The failure is passed directly to the caller, wrapped into a {@link GuardrailException} *

diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java index e6b8d7c83..afcd073fc 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.guardrails; +import java.util.Arrays; + import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.rag.AugmentationResult; @@ -11,89 +13,20 @@ *

* Implementation should be exposed as a CDI bean, and the class name configured in {@link OutputGuardrails#value()} annotation. *

- * Implementation should throw a {@link ValidationException} when the validation fails. The exception can indicate whether the - * request should be retried and provide a {@code reprompt} message. * In the case of reprompting, the reprompt message is added to the LLM context and the request is retried. *

* The maximum number of retries is configurable using {@code quarkus.langchain4j.guardrails.max-retries}, defaulting to 3. */ @Experimental("This feature is experimental and the API is subject to change") -public interface OutputGuardrail { - - /** - * An exception thrown when the validation fails. - */ - class ValidationException extends Exception { - - private final boolean retry; - private final String reprompt; - - /** - * Creates a new instance of {@link ValidationException} without a cause. - * - * @param message the error message - * @param retry whether the request should be retried - * @param reprompt if the request should be retried, the reprompt message. If null, the original request - * with the same context will be retried. - */ - public ValidationException(String message, boolean retry, String reprompt) { - super(message); - this.retry = retry; - this.reprompt = reprompt; - if (reprompt != null && !retry) { - throw new IllegalArgumentException("Reprompt message is only allowed if retry is true"); - } - } - - /** - * Creates a new instance of {@link ValidationException} with a cause. - * - * @param message the error message - * @param cause the cause - * @param retry whether the request should be retried - * @param reprompt if the request should be retried, the reprompt message. If null, the original request with - * the same context will be retried. - */ - public ValidationException(String message, Throwable cause, boolean retry, String reprompt) { - super(message, cause); - this.retry = retry; - this.reprompt = reprompt; - } - - /** - * Whether the request should be retried. - * - * @return true if the request should be retried, false otherwise. - */ - public boolean isRetry() { - return retry; - } - - /** - * The reprompt message. - *

- * If {@code isRetry()} returns true, the reprompt message (if not {@code null} is added to the LLM context and - * the request is retried. If the reprompt message is {@code null}, the original request with the same context. - *

- * If {@code isRetry()} returns false, the reprompt message is ignored. - * - * @return the reprompt message, or null if the original request with the same context should be retried. - */ - public String getReprompt() { - return reprompt; - } - } +public interface OutputGuardrail extends Guardrail { /** * Validates the response from the LLM. - *

- * If the validation fails with an exception that is not a {@link ValidationException}, no retry will be attempted. * * @param responseFromLLM the response from the LLM - * @throws ValidationException the exception throws if the validation fails. */ - default void validate(AiMessage responseFromLLM) throws ValidationException { - throw new ValidationException("Validation not implemented", false, null); + default OutputGuardrailResult validate(AiMessage responseFromLLM) { + return failure("Validation not implemented"); } /** @@ -102,17 +35,14 @@ default void validate(AiMessage responseFromLLM) throws ValidationException { * Unlike {@link #validate(AiMessage)}, this method allows to access the memory and the augmentation result (in the * case of a RAG). *

- * If the validation fails with an exception that is not a {@link ValidationException}, no retry will be attempted. - *

* Implementation must not attempt to write to the memory or the augmentation result. * * @param params the parameters, including the response from the LLM, the memory (maybe null), * and the augmentation result (maybe null). Cannot be {@code null} - * @throws ValidationException the exception throws if the validation fails. */ - default void validate(OutputGuardrailParams params) - throws ValidationException { - validate(params.responseFromLLM()); + @Override + default OutputGuardrailResult validate(OutputGuardrailParams params) { + return validate(params.responseFromLLM()); } /** @@ -122,7 +52,91 @@ default void validate(OutputGuardrailParams params) * @param memory the memory, can be {@code null} or empty * @param augmentationResult the augmentation result, can be {@code null} */ - record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory, AugmentationResult augmentationResult) { + record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory, + AugmentationResult augmentationResult) implements GuardrailParams { + } + + /** + * @return The result of a successful output guardrail validation. + */ + default OutputGuardrailResult success() { + return OutputGuardrailResult.success(); + } + + /** + * @param message A message describing the failure. + * @return The result of a failed output guardrail validation. + */ + default OutputGuardrailResult failure(String message) { + return new OutputGuardrailResult(Arrays.asList(new OutputGuardrailResult.Failure(message)), false); + } + + /** + * @param message A message describing the failure. + * @param cause The exception that caused this failure. + * @return The result of a failed output guardrail validation. + */ + default OutputGuardrailResult failure(String message, Throwable cause) { + return new OutputGuardrailResult(Arrays.asList(new OutputGuardrailResult.Failure(message, cause)), false); } + /** + * @param message A message describing the failure. + * @return The result of a fatally failed output guardrail validation, blocking the evaluation of any other subsequent + * validation. + */ + default OutputGuardrailResult fatal(String message) { + return new OutputGuardrailResult(Arrays.asList(new OutputGuardrailResult.Failure(message)), true); + } + + /** + * @param message A message describing the failure. + * @param cause The exception that caused this failure. + * @return The result of a fatally failed output guardrail validation, blocking the evaluation of any other subsequent + * validation. + */ + default OutputGuardrailResult fatal(String message, Throwable cause) { + return new OutputGuardrailResult(Arrays.asList(new OutputGuardrailResult.Failure(message, cause)), true); + } + + /** + * @param message A message describing the failure. + * @return The result of a fatally failed output guardrail validation, blocking the evaluation of any other subsequent + * validation and triggering a retry with the same user prompt. + */ + default OutputGuardrailResult retry(String message) { + return new OutputGuardrailResult(Arrays.asList(new OutputGuardrailResult.Failure(message, null, true)), true); + } + + /** + * @param message A message describing the failure. + * @param cause The exception that caused this failure. + * @return The result of a fatally failed output guardrail validation, blocking the evaluation of any other subsequent + * validation and triggering a retry with the same user prompt. + */ + default OutputGuardrailResult retry(String message, Throwable cause) { + return new OutputGuardrailResult(Arrays.asList(new OutputGuardrailResult.Failure(message, cause, true)), true); + } + + /** + * @param message A message describing the failure. + * @param reprompt The new prompt to be used for the retry. + * @return The result of a fatally failed output guardrail validation, blocking the evaluation of any other subsequent + * validation and triggering a retry with a new user prompt. + */ + default OutputGuardrailResult reprompt(String message, String reprompt) { + return new OutputGuardrailResult(Arrays.asList(new OutputGuardrailResult.Failure(message, null, true, reprompt)), true); + } + + /** + * @param message A message describing the failure. + * @param cause The exception that caused this failure. + * @param reprompt The new prompt to be used for the retry. + * @return The result of a fatally failed output guardrail validation, blocking the evaluation of any other subsequent + * validation and triggering a retry with a new user prompt. + */ + default OutputGuardrailResult reprompt(String message, Throwable cause, String reprompt) { + return new OutputGuardrailResult(Arrays.asList(new OutputGuardrailResult.Failure(message, cause, true, reprompt)), + true); + } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java new file mode 100644 index 000000000..7c5bbf33d --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java @@ -0,0 +1,106 @@ +package io.quarkiverse.langchain4j.guardrails; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * The result of the validation of an {@link OutputGuardrail} + * + * @param result The result of the output guardrail validation. + * @param failures The list of failures, empty if the validation succeeded. + */ +public record OutputGuardrailResult(Result result, List failures) implements GuardrailResult { + + private static final OutputGuardrailResult SUCCESS = new OutputGuardrailResult(); + + private OutputGuardrailResult() { + this(Result.SUCCESS, Collections.emptyList()); + } + + OutputGuardrailResult(List failures, boolean fatal) { + this(fatal ? Result.FATAL : Result.FAILURE, failures); + } + + public static OutputGuardrailResult success() { + return SUCCESS; + } + + public static OutputGuardrailResult failure(List failures) { + return new OutputGuardrailResult((List) failures, false); + } + + @Override + public boolean isSuccess() { + return result == Result.SUCCESS; + } + + public boolean isRetry() { + return !isSuccess() && failures.stream().anyMatch(Failure::retry); + } + + public String getReprompt() { + if (!isSuccess()) { + for (Failure failure : failures) { + if (failure.reprompt() != null) { + return failure.reprompt(); + } + } + } + return null; + } + + @Override + public boolean isFatal() { + return result == Result.FATAL; + } + + @Override + public OutputGuardrailResult validatedBy(Class guardrailClass) { + if (!isSuccess()) { + if (failures.size() != 1) { + throw new IllegalArgumentException(); + } + failures.set(0, failures.get(0).withGuardrailClass(guardrailClass)); + } + return this; + } + + @Override + public String toString() { + if (isSuccess()) { + return "success"; + } + return failures.stream().map(Failure::toString).collect(Collectors.joining(", ")); + } + + record Failure(String message, Throwable cause, Class guardrailClass, boolean retry, + String reprompt) implements GuardrailResult.Failure { + public Failure(String message) { + this(message, null); + } + + public Failure(String message, Throwable cause) { + this(message, cause, false); + } + + public Failure(String message, Throwable cause, boolean retry) { + this(message, cause, null, retry, null); + } + + public Failure(String message, Throwable cause, boolean retry, String reprompt) { + this(message, cause, null, retry, reprompt); + } + + @Override + public Failure withGuardrailClass(Class guardrailClass) { + return new Failure(message(), cause(), guardrailClass, retry, reprompt); + } + + @Override + public String toString() { + return "The guardrail " + guardrailClass.getName() + " failed with this message: " + message; + } + + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrails.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrails.java index 0fd6a742d..906afefbd 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrails.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrails.java @@ -13,9 +13,8 @@ *

* A guardrail is a rule that is applied to the output of the model to ensure that the output is safe and meets the * expectations. - * When a validation fails, the guardrail throws a - * {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrail.ValidationException}. - * The exception can indicate whether the request should be retried and provide a {@code reprompt} message. + * When a validation fails, the result can indicate whether the request should be retried and provide a {@code reprompt} + * message. *

* In the case of reprompting, the reprompt message is added to the LLM context and the request is retried. * diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailException.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailException.java similarity index 50% rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailException.java rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailException.java index 33f3d8a41..0298387c5 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailException.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailException.java @@ -1,13 +1,15 @@ -package io.quarkiverse.langchain4j.guardrails; +package io.quarkiverse.langchain4j.runtime.aiservice; /** * Exception thrown when a input or output guardrail validation fails. *

- * This exception is not intended to be used in guardrail implementation. Instead, guardrail implementations should throw - * {@link io.quarkiverse.langchain4j.guardrails.OutputGuardrail.ValidationException} or - * {@link io.quarkiverse.langchain4j.guardrails.InputGuardrail.ValidationException} when the validation fails. + * This exception is not intended to be used in guardrail implementation. */ public class GuardrailException extends RuntimeException { + public GuardrailException(String message) { + super(message); + } + public GuardrailException(String message, Throwable cause) { super(message, cause); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java index 20ac54bca..876a0f247 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java @@ -2,7 +2,9 @@ import static dev.langchain4j.data.message.UserMessage.userMessage; +import java.util.ArrayList; import java.util.List; +import java.util.function.Function; import jakarta.enterprise.inject.spi.CDI; @@ -13,20 +15,27 @@ import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.output.Response; import dev.langchain4j.rag.AugmentationResult; -import io.quarkiverse.langchain4j.guardrails.GuardrailException; +import io.quarkiverse.langchain4j.guardrails.Guardrail; +import io.quarkiverse.langchain4j.guardrails.GuardrailParams; +import io.quarkiverse.langchain4j.guardrails.GuardrailResult; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; public class GuardrailsSupport { public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage, ChatMemory chatMemory, AugmentationResult augmentationResult) { - InputGuardRailsResult result = invokeInputGuardRails(methodCreateInfo, - new InputGuardrail.InputGuardrailParams(userMessage, chatMemory, augmentationResult)); - if (!result.success()) { - throw new GuardrailException( - "Input validation failed. The guardrail " + result.bean().getName() + " thrown an exception", - result.failure()); + InputGuardrailResult result; + try { + result = invokeInputGuardRails(methodCreateInfo, + new InputGuardrail.InputGuardrailParams(userMessage, chatMemory, augmentationResult)); + } catch (Exception e) { + throw new GuardrailException(e.getMessage(), e); + } + if (!result.isSuccess()) { + throw new GuardrailException(result.toString(), result.getFirstFailureException()); } } @@ -42,15 +51,19 @@ public static Response invokeOutputGuardrails(AiServiceMethodCreateIn max = 1; } while (attempt < max) { - OutputGuardRailsResult grr = invokeOutputGuardRails(methodCreateInfo, output); - if (!grr.success) { - if (!grr.retry()) { - throw new GuardrailException( - "Output validation failed. The guardrail " + grr.bean().getName() + " thrown an exception", - grr.failure()); - } else if (grr.reprompt() != null) { + OutputGuardrailResult result; + try { + result = invokeOutputGuardRails(methodCreateInfo, output); + } catch (Exception e) { + throw new GuardrailException(e.getMessage(), e); + } + + if (!result.isSuccess()) { + if (!result.isRetry()) { + throw new GuardrailException(result.toString(), result.getFirstFailureException()); + } else if (result.getReprompt() != null) { // Retry with re-prompting - chatMemory.add(userMessage(grr.reprompt())); + chatMemory.add(userMessage(result.getReprompt())); if (toolSpecifications == null) { response = chatModel.generate(chatMemory.messages()); } else { @@ -73,17 +86,16 @@ public static Response invokeOutputGuardrails(AiServiceMethodCreateIn } if (attempt == max) { - throw new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries", - null); + throw new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries"); } return response; } @SuppressWarnings("unchecked") - private static OutputGuardRailsResult invokeOutputGuardRails(AiServiceMethodCreateInfo methodCreateInfo, + private static OutputGuardrailResult invokeOutputGuardRails(AiServiceMethodCreateInfo methodCreateInfo, OutputGuardrail.OutputGuardrailParams params) { if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) { - return OutputGuardRailsResult.SUCCESS; + return OutputGuardrailResult.success(); } List> classes; synchronized (AiServiceMethodImplementationSupport.class) { @@ -103,24 +115,14 @@ private static OutputGuardRailsResult invokeOutputGuardRails(AiServiceMethodCrea } } - for (Class bean : classes) { - try { - CDI.current().select(bean).get().validate(params); - } catch (OutputGuardrail.ValidationException e) { - return new OutputGuardRailsResult(false, bean, e, e.isRetry(), e.getReprompt()); - } catch (Exception e) { - return new OutputGuardRailsResult(false, bean, e, false, null); - } - } - - return OutputGuardRailsResult.SUCCESS; + return guardrailResult(params, (List) classes, OutputGuardrailResult.success(), OutputGuardrailResult::failure); } @SuppressWarnings("unchecked") - private static InputGuardRailsResult invokeInputGuardRails(AiServiceMethodCreateInfo methodCreateInfo, + private static InputGuardrailResult invokeInputGuardRails(AiServiceMethodCreateInfo methodCreateInfo, InputGuardrail.InputGuardrailParams params) { if (methodCreateInfo.getInputGuardrailsClassNames().isEmpty()) { - return InputGuardRailsResult.SUCCESS; + return InputGuardrailResult.success(); } List> classes; synchronized (AiServiceMethodImplementationSupport.class) { @@ -140,27 +142,34 @@ private static InputGuardRailsResult invokeInputGuardRails(AiServiceMethodCreate } } - for (Class bean : classes) { - try { - CDI.current().select(bean).get().validate(params); - } catch (Exception e) { - return new InputGuardRailsResult(false, bean, e); - } - } - - return InputGuardRailsResult.SUCCESS; + return guardrailResult(params, (List) classes, InputGuardrailResult.success(), InputGuardrailResult::failure); } - private record OutputGuardRailsResult(boolean success, Class bean, Exception failure, - boolean retry, String reprompt) { - - static OutputGuardRailsResult SUCCESS = new OutputGuardRailsResult(true, null, null, false, null); + private static GR guardrailResult(GuardrailParams params, + List> classes, GR accumulatedResults, + Function, GR> producer) { + for (Class bean : classes) { + GR result = (GR) CDI.current().select(bean).get().validate(params).validatedBy(bean); + if (result.isFatal()) { + return result; + } + accumulatedResults = compose(accumulatedResults, result, producer); + } + return accumulatedResults; } - private record InputGuardRailsResult(boolean success, Class bean, Exception failure) { - - static InputGuardRailsResult SUCCESS = new InputGuardRailsResult(true, null, null); - + private static GR compose(GR first, GR second, + Function, GR> producer) { + if (first.isSuccess()) { + return second; + } + if (second.isSuccess()) { + return first; + } + List failures = new ArrayList<>(); + failures.addAll(first.failures()); + failures.addAll(second.failures()); + return producer.apply(failures); } } diff --git a/docs/modules/ROOT/pages/guardrails.adoc b/docs/modules/ROOT/pages/guardrails.adoc index 072b8830e..810674e65 100644 --- a/docs/modules/ROOT/pages/guardrails.adoc +++ b/docs/modules/ROOT/pages/guardrails.adoc @@ -28,27 +28,23 @@ package io.quarkiverse.langchain4j.guardrails; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.rag.AugmentationResult; -import io.smallrye.common.annotation.Experimental; /** * An input guardrail is a rule that is applied to the input of the model to ensure that the input (the user message) is * safe and meets the expectations of the model. *

* Implementation should be exposed as a CDI bean, and the class name configured in {@link InputGuardrails#value()} annotation. - *

- * Implementation should throw a {@link ValidationException} when the validation fails. */ -public interface InputGuardrail { +public interface InputGuardrail extends Guardrail { /** * Validates the {@code user message} that will be sent to the LLM. *

* * @param userMessage the response from the LLM - * @throws ValidationException the exception throws if the validation fails. */ - default void validate(UserMessage userMessage) throws ValidationException { - throw new ValidationException("Validation not implemented"); + default InputGuardrailResult validate(UserMessage userMessage) { + return failure("Validation not implemented"); } /** @@ -61,11 +57,10 @@ public interface InputGuardrail { * * @param params the parameters, including the user message, the memory (maybe null), * and the augmentation result (maybe null). Cannot be {@code null} - * @throws ValidationException the exception throws if the validation fails. */ - default void validate(InputGuardrailParams params) - throws ValidationException { - validate(params.userMessage()); + @Override + default InputGuardrailResult validate(InputGuardrailParams params) { + return validate(params.userMessage()); } /** @@ -75,7 +70,8 @@ public interface InputGuardrail { * @param memory the memory, can be {@code null} or empty * @param augmentationResult the augmentation result, can be {@code null} */ - record InputGuardrailParams(UserMessage userMessage, ChatMemory memory, AugmentationResult augmentationResult) { + record InputGuardrailParams(UserMessage userMessage, ChatMemory memory, + AugmentationResult augmentationResult) implements GuardrailParams { } } @@ -84,8 +80,8 @@ public interface InputGuardrail { The `validate` method of the `InputGuardrail` interface can have two signatures: -- `void validate(UserMessage responseFromLLM) throws ValidationException` -- `void validate(InputGuardrailParams params) throws ValidationException` +- `InputGuardrailParams validate(UserMessage responseFromLLM)` +- `InputGuardrailParams validate(InputGuardrailParams params)` The first one is used when the guardrail only needs the user message. Simple guardrails can use this method. @@ -94,13 +90,11 @@ For example, they can check that there are enough documents in the augmentation ==== Input Guardrails Outcome -Input guardrails can have two outcomes: +Input guardrails can have three outcomes: - _pass_ - The input is valid, the next guardrail is executed, and if the last guardrail passes, the LLM is called. -- _fail_ - The input is invalid, the next guardrail is **not** executed, and the error is rethrown. The LLM is not called. - -A `validate` method completing successfully is considered a pass. -A `validate` method throwing an `Exception` is considered a fail. +- _fail_ - The input is invalid, but the next guardrail is executed the same, in order to accumulate all the possible validation problems. The LLM is not called. +- _fatal_ - The input is invalid, the next guardrail is **not** executed, and the error is rethrown. The LLM is not called. ==== Input Guardrails Scopes @@ -166,10 +160,35 @@ public interface Simulator { } ---- -In this example, the `VerifyHeroFormat` is executed first to check that the passed hero is valid +In this example, the `VerifyHeroFormat` is executed first to check that the passed hero is valid. Then, the `VerifyVillainFormat` is executed to check that the villain is valid. -If the `VerifyHeroFormat` fails, the `VerifyVillainFormat` is not executed. +If the `VerifyHeroFormat` fails, the `VerifyVillainFormat` may or may not be executed depending on whether the failure is fatal or not. For instance the `VerifyHeroFormat` could be implemented as it follows. + +[source,java] +---- +import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class VerifyHeroFormat implements InputGuardrail { + + @Override + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { + String text = um.singleText(); + if (text.length() > 1000) { + // a fatal failure, the next InputGuardrail won't be evaluated + return fatal("Input too long, size = " + text.length()); + } + if (!text.contains("hero")) { + // a normal failure, still allowing to evaluate also the next InputGuardrail and accumulate multiple failures + return failure("The input should contain the word 'hero'"); + } + return success(); + } +} +---- == Output Guardrails @@ -188,29 +207,24 @@ import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.rag.AugmentationResult; /** - * A guardrail is a rule that is applied to the output of the model to ensure that the output is safe and meets the + * An output guardrail is a rule that is applied to the output of the model to ensure that the output is safe and meets the * expectations. *

* Implementation should be exposed as a CDI bean, and the class name configured in {@link OutputGuardrails#value()} annotation. *

- * Implementation should throw a {@link ValidationException} when the validation fails. The exception can indicate whether the - * request should be retried and provide a {@code reprompt} message. * In the case of reprompting, the reprompt message is added to the LLM context and the request is retried. *

* The maximum number of retries is configurable using {@code quarkus.langchain4j.guardrails.max-retries}, defaulting to 3. */ -public interface OutputGuardrail { +public interface OutputGuardrail extends Guardrail { /** * Validates the response from the LLM. - *

- * If the validation fails with an exception that is not a {@link ValidationException}, no retry will be attempted. * * @param responseFromLLM the response from the LLM - * @throws ValidationException the exception throws if the validation fails. */ - default void validate(AiMessage responseFromLLM) throws ValidationException { - throw new ValidationException("Validation not implemented", false, null); + default OutputGuardrailResult validate(AiMessage responseFromLLM) { + return failure("Validation not implemented"); } /** @@ -219,17 +233,14 @@ public interface OutputGuardrail { * Unlike {@link #validate(AiMessage)}, this method allows to access the memory and the augmentation result (in the * case of a RAG). *

- * If the validation fails with an exception that is not a {@link ValidationException}, no retry will be attempted. - *

* Implementation must not attempt to write to the memory or the augmentation result. * * @param params the parameters, including the response from the LLM, the memory (maybe null), * and the augmentation result (maybe null). Cannot be {@code null} - * @throws ValidationException the exception throws if the validation fails. */ - default void validate(OutputGuardrailParams params) - throws ValidationException { - validate(params.responseFromLLM()); + @Override + default OutputGuardrailResult validate(OutputGuardrailParams params) { + return validate(params.responseFromLLM()); } /** @@ -239,17 +250,17 @@ public interface OutputGuardrail { * @param memory the memory, can be {@code null} or empty * @param augmentationResult the augmentation result, can be {@code null} */ - record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory, AugmentationResult augmentationResult) { + record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory, + AugmentationResult augmentationResult) implements GuardrailParams { } - } ---- The `validate` method of the `OutputGuardrail` interface can have two signatures: -- `void validate(AiMessage responseFromLLM) throws ValidationException` -- `void validate(OutputGuardrailParams params) throws ValidationException` +- `OutputGuardrailParams validate(AiMessage responseFromLLM)` +- `OutputGuardrailParams validate(OutputGuardrailParams params)` The first one is used when the guardrail only needs the output of the LLM. Simple guardrails can use this method. @@ -270,12 +281,13 @@ public class JsonGuardrail implements OutputGuardrail { ObjectMapper mapper; @Override - public void validate(AiMessage responseFromLLM) throws ValidationException { + public OutputGuardrailResult validate(AiMessage responseFromLLM) { try { mapper.readTree(responseFromLLM.text()); } catch (Exception e) { - throw new ValidationException("Invalid JSON", true, "Make sure you return a valid JSON object"); + return reprompt("Invalid JSON", e, "Make sure you return a valid JSON object"); } + return success(); } } @@ -287,31 +299,28 @@ The <<_detecting_hallucinations_in_the_rag_context>> section gives an example of ==== Output Guardrails Outcome -Output guardrails can have four outcomes: +Output guardrails can have five outcomes: - _pass_ - The output is valid, the next guardrail is executed, and if the last guardrail passes, the output is returned to the caller. -- _fail_ - The output is invalid, the next guardrail is **not** executed, and the error is rethrown. -- _fail with retry_ - The output is invalid, the next guardrail is **not** executed, and the LLM is called again with the **same** prompt. -- _fail with reprompt_ - The output is invalid, the next guardrail is **not** executed, and the LLM is called again with a **new** prompt. +- _fail_ - The output is invalid, but the next guardrail is executed the same, in order to accumulate all the possible validation problems. +- _fatal_ - The output is invalid, the next guardrail is **not** executed, and the error is rethrown. +- _fatal with retry_ - The output is invalid, the next guardrail is **not** executed, and the LLM is called again with the **same** prompt. +- _fatal with reprompt_ - The output is invalid, the next guardrail is **not** executed, and the LLM is called again with a **new** prompt. -A `validate` method completing successfully is considered a pass. -A `validate` method throwing an `Exception` is considered a fail. -If that exception is a `io.quarkiverse.langchain4j.guardrails.OutputGuardrail.ValidationException` exception, then the guardrail can specify whether the LLM should be retried or reprompted. +In fact if the validation fails, then the guardrail can specify whether the LLM should be retried or reprompted. [source,java] ---- // Retry - The LLM is called again with the same prompt and context // The guardrails will be called again with the new output -throw new ValidationException("Invalid JSON", true, null); +return retry("Invalid JSON"); // Retry with reprompt - The LLM is called again with a new prompt and context // A new user message is added to the LLM context (memory), and the LLM is called again with this new context. // The guardrails will be called again with the new output -throw new ValidationException("Invalid JSON", true, "Make sure you return a valid JSON object"); +return reprompt("Invalid JSON", "Make sure you return a valid JSON object"); ---- -IMPORTANT: _Reprompting_ requires the `retry` parameter to be set to `true` in the `ValidationException` constructor. - By default, Quarkus Langchain4J will limit the number of retries to 3. This is configurable using the `quarkus.langchain4j.guardrails.max-retries` configuration property: @@ -411,6 +420,8 @@ import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.model.output.Response; import dev.langchain4j.rag.content.Content; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; import io.quarkus.logging.Log; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; @@ -426,11 +437,11 @@ public class HallucinationGuard implements OutputGuardrail { double threshold; @Override - public void validate(OutputGuardrailParams params) throws ValidationException { + public OutputGuardrailResult validate(OutputGuardrailParams params) { Response embeddingOfTheResponse = embedding.embed(params.responseFromLLM().text()); if (params.augmentationResult() == null || params.augmentationResult().contents().isEmpty()) { Log.info("No content to validate against"); - return; + return success(); } float[] vectorOfTheResponse = embeddingOfTheResponse.content().vector(); for (Content content : params.augmentationResult().contents()) { @@ -439,11 +450,11 @@ public class HallucinationGuard implements OutputGuardrail { double distance = cosineDistance(vectorOfTheResponse, vectorOfTheContent); if (distance < threshold) { Log.info("Passed hallucination guardrail: " + distance); - return; + return success(); } } - throw new ValidationException("Hallucination detected", true, "Make sure you use the given documents to produce the response"); + return reprompt("Hallucination detected", "Make sure you use the given documents to produce the response"); } public static double cosineDistance(float[] vector1, float[] vector2) {