Skip to content

Commit

Permalink
Merge pull request #883 from mariofusco/guardrails-api
Browse files Browse the repository at this point in the history
Avoid using exceptions for control flow in guardrail's API
  • Loading branch information
geoand authored Sep 16, 2024
2 parents 20ed9ff + 0db2df0 commit b61bc88
Show file tree
Hide file tree
Showing 29 changed files with 697 additions and 335 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Multi;
Expand Down Expand Up @@ -121,9 +123,10 @@ public static class MyInputGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger();

@Override
public void validate(InputGuardrailParams params) throws ValidationException {
public InputGuardrailResult validate(InputGuardrailParams params) {
spy.incrementAndGet();
assertThat(params.augmentationResult().contents()).hasSize(2);
return InputGuardrailResult.success();
}

public int getSpy() {
Expand All @@ -137,9 +140,10 @@ public static class MyOutputGuardrail implements OutputGuardrail {
AtomicInteger spy = new AtomicInteger();

@Override
public void validate(OutputGuardrailParams params) throws ValidationException {
public OutputGuardrailResult validate(OutputGuardrailParams params) {
spy.incrementAndGet();
assertThat(params.augmentationResult().contents()).hasSize(2);
return OutputGuardrailResult.success();
}

public int getSpy() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import io.quarkus.test.QuarkusUnitTest;

Expand All @@ -38,7 +40,8 @@ public class InputAndOutputGuardrailsTest {
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyAiService.class,
MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class,
ValidationException.class));

@Inject
MyOkInputGuardrail okIn;
Expand Down Expand Up @@ -144,8 +147,9 @@ public static class MyOkInputGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger();

@Override
public void validate(InputGuardrailParams params) throws ValidationException {
public InputGuardrailResult validate(InputGuardrailParams params) {
spy.incrementAndGet();
return success();
}

public int getSpy() {
Expand All @@ -159,9 +163,9 @@ public static class MyKoInputGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger();

@Override
public void validate(InputGuardrailParams params) throws ValidationException {
public InputGuardrailResult validate(InputGuardrailParams params) {
spy.incrementAndGet();
throw new ValidationException("boom");
return failure("boom", new ValidationException("boom"));
}

public int getSpy() {
Expand All @@ -175,8 +179,9 @@ public static class MyOkOutputGuardrail implements OutputGuardrail {
AtomicInteger spy = new AtomicInteger();

@Override
public void validate(OutputGuardrailParams params) throws ValidationException {
public OutputGuardrailResult validate(OutputGuardrailParams params) {
spy.incrementAndGet();
return OutputGuardrailResult.success();
}

public int getSpy() {
Expand All @@ -190,9 +195,9 @@ public static class MyKoOutputGuardrail implements OutputGuardrail {
AtomicInteger spy = new AtomicInteger();

@Override
public void validate(OutputGuardrailParams params) throws ValidationException {
public OutputGuardrailResult validate(OutputGuardrailParams params) {
spy.incrementAndGet();
throw new ValidationException("boom", false, null);
return failure("boom", new ValidationException("boom"));
}

public int getSpy() {
Expand All @@ -206,10 +211,11 @@ public static class MyKoWithRetryOutputGuardrail implements OutputGuardrail {
AtomicInteger spy = new AtomicInteger();

@Override
public void validate(OutputGuardrailParams params) throws ValidationException {
public OutputGuardrailResult validate(OutputGuardrailParams params) {
if (spy.incrementAndGet() == 1) {
throw new ValidationException("KO", true, null);
return retry("KO");
}
return success();
}

public int getSpy() {
Expand All @@ -223,10 +229,11 @@ public static class MyKoWithRepromprOutputGuardrail implements OutputGuardrail {
AtomicInteger spy = new AtomicInteger();

@Override
public void validate(OutputGuardrailParams params) throws ValidationException {
public OutputGuardrailResult validate(OutputGuardrailParams params) {
if (spy.incrementAndGet() == 1) {
throw new ValidationException("KO", true, "retry");
return reprompt("KO", "retry");
}
return success();
}

public int getSpy() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.GuardrailException;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkus.test.QuarkusUnitTest;

Expand All @@ -38,7 +39,8 @@ public class InputGuardrailChainTest {
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyAiService.class,
MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class,
ValidationException.class));

@Inject
MyAiService aiService;
Expand Down Expand Up @@ -74,7 +76,7 @@ void testThatGuardrailOrderIsCorrect() {
void testFailureTheChain() {
assertThatThrownBy(() -> aiService.failingFirstTwo("1", "foo"))
.isInstanceOf(GuardrailException.class)
.hasCauseInstanceOf(InputGuardrail.ValidationException.class)
.hasCauseInstanceOf(ValidationException.class)
.hasRootCauseMessage("boom");
assertThat(firstGuardrail.spy()).isEqualTo(1);
assertThat(secondGuardrail.spy()).isEqualTo(0);
Expand Down Expand Up @@ -102,14 +104,15 @@ public static class FirstGuardrail implements InputGuardrail {
AtomicLong lastAccess = new AtomicLong();

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
spy.incrementAndGet();
lastAccess.set(System.nanoTime());
try {
Thread.sleep(1);
} catch (InterruptedException e) {
// Ignore me
}
return success();
}

public int spy() {
Expand All @@ -128,14 +131,15 @@ public static class SecondGuardrail implements InputGuardrail {
volatile AtomicLong lastAccess = new AtomicLong();

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
spy.incrementAndGet();
lastAccess.set(System.nanoTime());
try {
Thread.sleep(1);
} catch (InterruptedException e) {
// Ignore me
}
return success();
}

public int spy() {
Expand All @@ -153,10 +157,11 @@ public static class FailingGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger(0);

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) throws ValidationException {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
if (spy.incrementAndGet() == 1) {
throw new ValidationException("boom");
return fatal("boom", new ValidationException("boom"));
}
return success();
}

public int spy() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkus.test.QuarkusUnitTest;
Expand Down Expand Up @@ -59,7 +60,7 @@ public interface MyAiService {
public static class MissingGuardRail implements InputGuardrail {

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
throw new RuntimeException("Should not be invoked");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkus.test.QuarkusUnitTest;
Expand Down Expand Up @@ -74,8 +75,9 @@ public static class OKGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger(0);

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
spy.incrementAndGet();
return success();
}

public int spy() {
Expand All @@ -89,9 +91,9 @@ public static class KOGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger(0);

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) throws ValidationException {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
spy.incrementAndGet();
throw new ValidationException("KO");
return failure("KO");
}

public int spy() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkus.test.QuarkusUnitTest;
Expand Down Expand Up @@ -68,8 +69,9 @@ public static class OKGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger(0);

@Override
public void validate(dev.langchain4j.data.message.UserMessage ignored) {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage ignored) {
spy.incrementAndGet();
return success();
}

public int spy() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkus.arc.Arc;
Expand All @@ -38,7 +39,8 @@ public class InputGuardrailTest {
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class,
MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class,
ValidationException.class));

@Inject
MyAiService aiService;
Expand Down Expand Up @@ -73,10 +75,10 @@ void testThatInputGuardrailsAreInvoked() {
void testThatGuardrailCanThrowValidationException() {
assertThat(koGuardrail.spy()).isEqualTo(0);
assertThatThrownBy(() -> aiService.ko("1"))
.hasCauseExactlyInstanceOf(InputGuardrail.ValidationException.class);
.hasCauseExactlyInstanceOf(ValidationException.class);
assertThat(koGuardrail.spy()).isEqualTo(1);
assertThatThrownBy(() -> aiService.ko("1"))
.hasCauseExactlyInstanceOf(InputGuardrail.ValidationException.class);
.hasCauseExactlyInstanceOf(ValidationException.class);
assertThat(koGuardrail.spy()).isEqualTo(2);
}

Expand All @@ -99,8 +101,9 @@ public static class OKGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger(0);

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
spy.incrementAndGet();
return success();
}

public int spy() {
Expand All @@ -114,9 +117,9 @@ public static class KOGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger(0);

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) throws ValidationException {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
spy.incrementAndGet();
throw new ValidationException("KO");
return failure("KO", new ValidationException("KO"));
}

public int spy() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.GuardrailException;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Multi;

Expand All @@ -42,7 +43,7 @@ public class InputGuardrailValidationTest {
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyAiService.class, OKGuardrail.class, KOGuardrail.class,
MyChatModel.class, MyStreamingChatModel.class, MyChatModelSupplier.class,
MyMemoryProviderSupplier.class));
MyMemoryProviderSupplier.class, ValidationException.class));

@Inject
MyAiService aiService;
Expand Down Expand Up @@ -130,8 +131,9 @@ public static class OKGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger(0);

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
spy.incrementAndGet();
return success();
}

public int spy() {
Expand All @@ -145,9 +147,9 @@ public static class KOGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger(0);

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) throws ValidationException {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
spy.incrementAndGet();
throw new ValidationException("KO");
return failure("KO");
}

public int spy() {
Expand All @@ -161,7 +163,7 @@ public static class KOFatalGuardrail implements InputGuardrail {
AtomicInteger spy = new AtomicInteger(0);

@Override
public void validate(dev.langchain4j.data.message.UserMessage um) throws ValidationException {
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
spy.incrementAndGet();
throw new IllegalArgumentException("Fatal");
}
Expand All @@ -177,7 +179,7 @@ public static class MemoryCheck implements InputGuardrail {
AtomicInteger spy = new AtomicInteger(0);

@Override
public void validate(InputGuardrail.InputGuardrailParams params) {
public InputGuardrailResult validate(InputGuardrail.InputGuardrailParams params) {
spy.incrementAndGet();
if (params.memory().messages().isEmpty()) {
assertThat(params.userMessage().singleText()).isEqualTo("foo");
Expand All @@ -187,6 +189,7 @@ public void validate(InputGuardrail.InputGuardrailParams params) {
assertThat(params.memory().messages().get(1).text()).isEqualTo("Hi!");
assertThat(params.userMessage().singleText()).isEqualTo("bar");
}
return success();
}

public int spy() {
Expand Down
Loading

0 comments on commit b61bc88

Please sign in to comment.