Skip to content

Commit

Permalink
Merge pull request #1174 from edeandrea/add-testing-core
Browse files Browse the repository at this point in the history
Add a core testing module with some custom AssertJ assertions for output guardrails
  • Loading branch information
geoand authored Dec 20, 2024
2 parents d3bdc94 + cde48bc commit b7f861a
Show file tree
Hide file tree
Showing 15 changed files with 774 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,7 @@ interface Failure {
String message();

Throwable cause();

Class<? extends Guardrail> guardrailClass();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public String toString() {
return failures.stream().map(Failure::toString).collect(Collectors.joining(", "));
}

record Failure(String message, Throwable cause,
public record Failure(String message, Throwable cause,
Class<? extends Guardrail> guardrailClass) implements GuardrailResult.Failure {
public Failure(String message) {
this(message, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public String toString() {
return failures.stream().map(Failure::toString).collect(Collectors.joining(", "));
}

record Failure(String message, Throwable cause, Class<? extends Guardrail> guardrailClass, boolean retry,
public record Failure(String message, Throwable cause, Class<? extends Guardrail> guardrailClass, boolean retry,
String reprompt) implements GuardrailResult.Failure {
public Failure(String message) {
this(message, null);
Expand Down
226 changes: 226 additions & 0 deletions docs/modules/ROOT/pages/guardrails.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,229 @@ public class CustomersExtractionOutputGuardrail extends AbstractJsonExtractorOut
}
}
----

=== Unit testing
You can also unit test your output guardrails. A set of AssertJ custom assertions (following https://assertj.github.io/doc/#assertj-core-custom-assertions-entry-point[AssertJ's custom assertions pattern]) are available to help you unit test your guardrails.

To get access to these helpers you'll need to add a `test` dependency to `io.quarkiverse.langchain4j:quarkus-langchain4j-testing-core`:

[source,xml,subs=attributes+]
----
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-testing-core</artifactId>
<version>{project-version}</version>
<scope>test</scope>
</dependency>
----

Then you can import the helpers into your test class:

[source,java,subs=attributes+]
----
import static io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.*;
----

Then, for an `OutputGuardrail` class that looks something like this:

.src/main/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrail.java
[source,java,subs="+attributes,macros+"]
----
package io.quarkiverse.langchain4j.guardrails;
import java.util.Optional;
import dev.langchain4j.data.message.AiMessage;
public class EmailContainsRequiredInformationOutputGuardrail implements OutputGuardrail {
static final String NO_RESPONSE_MESSAGE = "No response found";
static final String NO_RESPONSE_PROMPT = "The response was empty. Please try again.";
static final String CLIENT_NAME_NOT_FOUND_MESSAGE = "Client name not found";
static final String CLIENT_NAME_NOT_FOUND_PROMPT = "The response did not contain the client name. Please include the client name \"%s\", exactly as is (case-sensitive), in the email body.";
static final String CLAIM_NUMBER_NOT_FOUND_MESSAGE = "Claim number not found";
static final String CLAIM_NUMBER_NOT_FOUND_PROMPT = "The response did not contain the claim number. Please include the claim number \"%s\", exactly as is (case-sensitive), in the email body.";
static final String CLAIM_STATUS_NOT_FOUND_MESSAGE = "Claim status not found";
static final String CLAIM_STATUS_NOT_FOUND_PROMPT = "The response did not contain the claim status. Please include the claim status \"%s\", exactly as is (case-sensitive), in the email body.";
@Override
public OutputGuardrailResult validate(OutputGuardrailParams params) {
var claimInfo = Optional.ofNullable(params.variables())
.map(vars -> vars.get("claimInfo"))
.map(ClaimInfo.class::cast)
.orElse(null);
if (claimInfo != null) {
var response = Optional.ofNullable(params.responseFromLLM())
.map(AiMessage::text)
.orElse("");
if (response.isBlank()) {
return reprompt(NO_RESPONSE_MESSAGE, NO_RESPONSE_PROMPT);
}
if (!claimInfo.clientName().isBlank() && !response.contains(claimInfo.clientName())) {
return reprompt(CLIENT_NAME_NOT_FOUND_MESSAGE, CLIENT_NAME_NOT_FOUND_PROMPT.formatted(claimInfo.clientName()));
}
if (!claimInfo.claimNumber().isBlank() && !response.contains(claimInfo.claimNumber())) {
return reprompt(CLAIM_NUMBER_NOT_FOUND_MESSAGE,
CLAIM_NUMBER_NOT_FOUND_PROMPT.formatted(claimInfo.claimNumber()));
}
if (!claimInfo.claimStatus().isBlank() && !response.contains(claimInfo.claimStatus())) {
return reprompt(CLAIM_STATUS_NOT_FOUND_MESSAGE,
CLAIM_STATUS_NOT_FOUND_PROMPT.formatted(claimInfo.claimStatus()));
}
}
return success();
}
}
----

You can test it by doing things like this:

.src/test/java/io/quarkiverse/langchain4j/guardrails/EmailContainsRequiredInformationOutputGuardrailTests.java
[source,java,subs="+attributes,macros+"]
----
package io.quarkiverse.langchain4j.guardrails;
import static io.quarkiverse.langchain4j.guardrails.GuardrailAssertions.assertThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import java.util.Map;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import dev.langchain4j.data.message.AiMessage;
import io.quarkiverse.langchain4j.guardrails.GuardrailResult.Result;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult.Failure;
class EmailContainsRequiredInformationOutputGuardrailTests {
private static final String CLAIM_NUMBER = "CLM195501";
private static final String CLAIM_STATUS = "denied";
private static final String CLIENT_NAME = "Marty McFly";
private static final String EMAIL_TEMPLATE = """
Dear %s,
We are writing to inform you that your claim (%s) has been reviewed and is currently under consideration. After careful evaluation of the evidence provided, we regret to inform you that your claim has been %s.
Please note that our decision is based on the information provided in your policy declarations page, as well as applicable laws and regulations governing vehicle insurance claims.
If you have any questions or concerns regarding this decision, please do not hesitate to contact us at 800-CAR-SAFE or email claims@parasol.com. A member of our team will be happy to assist you.
Sincerely,
Parasoft Insurance Claims Department
--------------------------------------------
Please note this is an unmonitored email box.
Should you choose to reply, nobody (not even an AI bot) will see your message.
Call a real human should you have any questions. 1-800-CAR-SAFE.
""";
EmailContainsRequiredInformationOutputGuardrail guardrail = spy(new EmailContainsRequiredInformationOutputGuardrail());
@Test
void guardrailSuccess() {
var response = EMAIL_TEMPLATE.formatted(CLIENT_NAME, CLAIM_NUMBER, CLAIM_STATUS);
var params = createParams(response, CLAIM_NUMBER, CLAIM_STATUS, CLIENT_NAME);
assertThat(this.guardrail.validate(params))
.isSuccessful(); // Custom assertion from OutputGuardrailResultAssert
verify(this.guardrail).validate(params);
verify(this.guardrail).success();
verifyNoMoreInteractions(this.guardrail);
}
@Test
void emptyEmail() {
var params = createParams("", CLAIM_NUMBER, CLAIM_STATUS, CLIENT_NAME);
var result = this.guardrail.validate(params);
assertThat(result)
// custom assertions from OutputGuardrailResultAssert
.hasFailures()
.hasResult(Result.FATAL)
.hasSingleFailureWithMessage(EmailContainsRequiredInformationOutputGuardrail.NO_RESPONSE_MESSAGE);
verify(this.guardrail).validate(params);
verify(this.guardrail).reprompt(EmailContainsRequiredInformationOutputGuardrail.NO_RESPONSE_MESSAGE,
EmailContainsRequiredInformationOutputGuardrail.NO_RESPONSE_PROMPT);
verifyNoMoreInteractions(this.guardrail);
}
@ParameterizedTest
@MethodSource("emailDoesntContainRequiredInfoParams")
void emailDoesntContainRequiredInfo(ClaimInfo missingClaimInfo, String expectedRepromptMessage, String expectedRepromptPrompt) {
var responseWithMissingInfo = EMAIL_TEMPLATE.formatted(missingClaimInfo.clientName(), missingClaimInfo.claimNumber(), missingClaimInfo.claimStatus());
var params = createParams(responseWithMissingInfo, CLAIM_NUMBER, CLAIM_STATUS, CLIENT_NAME);
var result = this.guardrail.validate(params);
assertThat(result)
// custom assertions from OutputGuardrailResultAssert
.hasFailures()
.hasResult(Result.FATAL)
.hasSingleFailureWithMessageAndReprompt(expectedRepromptMessage, expectedRepromptPrompt)
.assertSingleFailureSatisfies(failure ->
assertThat(failure)
.isNotNull()
.extracting(
Failure::retry,
Failure::message,
Failure::cause
)
.containsExactly(
true,
expectedRepromptMessage,
null)
);
verify(this.guardrail).validate(params);
verify(this.guardrail).reprompt(expectedRepromptMessage, expectedRepromptPrompt);
verifyNoMoreInteractions(this.guardrail);
}
static Stream<Arguments> emailDoesntContainRequiredInfoParams() {
return Stream.of(
Arguments.of(
new ClaimInfo("", CLAIM_NUMBER, CLAIM_STATUS),
EmailContainsRequiredInformationOutputGuardrail.CLIENT_NAME_NOT_FOUND_MESSAGE,
EmailContainsRequiredInformationOutputGuardrail.CLIENT_NAME_NOT_FOUND_PROMPT.formatted(CLIENT_NAME)
),
Arguments.of(
new ClaimInfo(CLIENT_NAME, "", CLAIM_STATUS),
EmailContainsRequiredInformationOutputGuardrail.CLAIM_NUMBER_NOT_FOUND_MESSAGE,
EmailContainsRequiredInformationOutputGuardrail.CLAIM_NUMBER_NOT_FOUND_PROMPT.formatted(CLAIM_NUMBER)
),
Arguments.of(
new ClaimInfo(CLIENT_NAME, CLAIM_NUMBER, ""),
EmailContainsRequiredInformationOutputGuardrail.CLAIM_STATUS_NOT_FOUND_MESSAGE,
EmailContainsRequiredInformationOutputGuardrail.CLAIM_STATUS_NOT_FOUND_PROMPT.formatted(CLAIM_STATUS)
)
);
}
private static OutputGuardrailParams createParams(String response, String claimNumber, String claimStatus, String clientName) {
return createParams(response, new ClaimInfo(clientName, claimNumber, claimStatus));
}
private static OutputGuardrailParams createParams(String response, ClaimInfo claimInfo) {
return new OutputGuardrailParams(
AiMessage.from(response),
null,
null,
null,
Map.of("claimInfo", claimInfo)
);
}
}
----

See https://github.com/quarkiverse/quarkus-langchain4j/blob/main/testing/core/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResultAssert.java[`OutputGuardrailResultAssert.java`] for more information about the different kinds of asserts you can do.
39 changes: 39 additions & 0 deletions testing/core/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-testing</artifactId>
<version>999-SNAPSHOT</version>
</parent>

<artifactId>quarkus-langchain4j-testing-core</artifactId>
<name>Quarkus LangChain4j - Testing - Core</name>
<description>Provides testing utilities common to the core module</description>

<dependencies>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj.version}</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.quarkiverse.langchain4j.guardrails;

import org.assertj.core.api.Assertions;

/**
* Custom assertions for working with Guardrails
* <p>
* This follows the pattern described in https://assertj.github.io/doc/#assertj-core-custom-assertions-entry-point
* </p>
*/
public class GuardrailAssertions extends Assertions {
public static OutputGuardrailResultAssert assertThat(OutputGuardrailResult actual) {
return OutputGuardrailResultAssert.assertThat(actual);
}
}
Loading

0 comments on commit b7f861a

Please sign in to comment.