Skip to content

Commit

Permalink
Merge pull request #40 from cescoffier/chatbot-rag-demo
Browse files Browse the repository at this point in the history
Demo of a chat bot using RAG.
  • Loading branch information
geoand authored Nov 20, 2023
2 parents 4fbdfad + 0b11c98 commit 02d4244
Show file tree
Hide file tree
Showing 21 changed files with 2,011 additions and 1 deletion.
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
<module>samples/email-a-poem</module>
<module>samples/review-triage</module>
<module>samples/fraud-detection</module>
<module>samples/chatbot</module>
</modules>
</profile>
</profiles>
Expand Down
151 changes: 151 additions & 0 deletions samples/chatbot/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-parent</artifactId>
<version>999-SNAPSHOT</version>
<relativePath>../..</relativePath>
</parent>

<artifactId>quarkus-langchain4j-sample-chatbot</artifactId>
<name>Quarkus langchain4j - Sample - Chatbot &amp; RAG</name>

<dependencies>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-resteasy-reactive-jackson</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-websockets</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-openai</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-redis</artifactId>
<version>${project.version}</version>
</dependency>

<!-- UI -->
<dependency>
<groupId>org.mvnpm</groupId>
<artifactId>importmap</artifactId>
<version>1.0.8</version>
</dependency>
<dependency>
<groupId>org.mvnpm.at.mvnpm</groupId>
<artifactId>vaadin-webcomponents</artifactId>
<version>24.2.1</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.mvnpm</groupId>
<artifactId>es-module-shims</artifactId>
<scope>runtime</scope>
<version>1.8.2</version>
</dependency>
<dependency>
<groupId>org.mvnpm</groupId>
<artifactId>wc-chatbot</artifactId>
<version>0.1.2</version>
<scope>runtime</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-maven-plugin</artifactId>
<version>${quarkus.version}</version>
<executions>
<execution>
<goals>
<goal>build</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.2.2</version>
<configuration>
<systemPropertyVariables>
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
<maven.home>${maven.home}</maven.home>
</systemPropertyVariables>
</configuration>
</plugin>
</plugins>
</build>

<profiles>
<profile>
<id>native</id>
<activation>
<property>
<name>native</name>
</property>
</activation>
<build>
<plugins>
<plugin>
<artifactId>maven-failsafe-plugin</artifactId>
<version>3.2.2</version>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
<configuration>
<systemPropertyVariables>
<native.image.path>${project.build.directory}/${project.build.finalName}-runner</native.image.path>
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
<maven.home>${maven.home}</maven.home>
</systemPropertyVariables>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
<properties>
<quarkus.package.type>native</quarkus.package.type>
</properties>
</profile>

<profile>
<id>mvnpm</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<repositories>
<repository>
<id>central</id>
<name>central</name>
<url>https://repo.maven.apache.org/maven2</url>
</repository>
<repository>
<snapshots>
<enabled>false</enabled>
</snapshots>
<id>mvnpm.org</id>
<name>mvnpm</name>
<url>https://repo.mvnpm.org/maven2</url>
</repository>
</repositories>
</profile>
</profiles>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.quarkiverse.langchain4j.sample.chatbot;

import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;

@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class, retrieverSupplier = RegisterAiService.BeanRetrieverSupplier.class)
public interface Bot {

@SystemMessage("""
You are an AI answering questions about financial products.
Your response must be polite, use the same language as the question, and be relevant to the question.
When you don't know, respond that you don't know the answer and the bank will contact the customer directly.
Introduce yourself with: "Hello, I'm Bob, how can I help you?"
""")
String chat(@MemoryId Object session, @UserMessage String question);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package io.quarkiverse.langchain4j.sample.chatbot;

import java.io.IOException;

import jakarta.inject.Inject;
import jakarta.websocket.*;
import jakarta.websocket.server.ServerEndpoint;

import io.smallrye.mutiny.infrastructure.Infrastructure;

@ServerEndpoint("/chatbot")
public class ChatBotWebSocket {

@Inject
Bot bot;

@Inject
ChatMemoryBean chatMemoryBean;

@OnOpen
public void onOpen(Session session) {
Infrastructure.getDefaultExecutor().execute(() -> {
String response = bot.chat(session, "hello");
try {
session.getBasicRemote().sendText(response);
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}

@OnClose
void onClose(Session session) {
chatMemoryBean.clear(session);
}

@OnMessage
public void onMessage(String message, Session session) {
Infrastructure.getDefaultExecutor().execute(() -> {
String response = bot.chat(session, message);
try {
session.getBasicRemote().sendText(response);
} catch (IOException e) {
throw new RuntimeException(e);
}
});

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package io.quarkiverse.langchain4j.sample.chatbot;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import jakarta.enterprise.context.ApplicationScoped;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;

@ApplicationScoped
public class ChatMemoryBean implements ChatMemoryProvider {

private final Map<Object, ChatMemory> memories = new ConcurrentHashMap<>();

@Override
public ChatMemory get(Object memoryId) {
return memories.computeIfAbsent(memoryId, id -> MessageWindowChatMemory.builder()
.maxMessages(20)
.id(memoryId)
.build());
}

public void clear(Object session) {
memories.remove(session);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package io.quarkiverse.langchain4j.sample.chatbot;

import jakarta.annotation.PostConstruct;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;

import org.mvnpm.importmap.Aggregator;

/**
* Dynamically create the import map
*/
@ApplicationScoped
@Path("/_importmap")
public class ImportmapResource {
private String importmap;

// See https://github.com/WICG/import-maps/issues/235
// This does not seem to be supported by browsers yet...
@GET
@Path("/dynamic.importmap")
@Produces("application/importmap+json")
public String importMap() {
return this.importmap;
}

@GET
@Path("/dynamic-importmap.js")
@Produces("application/javascript")
public String importMapJson() {
return JAVASCRIPT_CODE.formatted(this.importmap);
}

@PostConstruct
void init() {
Aggregator aggregator = new Aggregator();
// Add our own mappings
aggregator.addMapping("icons/", "/icons/");
aggregator.addMapping("components/", "/components/");
aggregator.addMapping("fonts/", "/fonts/");
this.importmap = aggregator.aggregateAsJson();
}

private static final String JAVASCRIPT_CODE = """
const im = document.createElement('script');
im.type = 'importmap';
im.textContent = JSON.stringify(%s);
document.currentScript.after(im);
""";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package io.quarkiverse.langchain4j.sample.chatbot;

import static dev.langchain4j.data.document.splitter.DocumentSplitters.recursive;

import java.io.File;
import java.util.List;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.event.Observes;
import jakarta.inject.Inject;

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.FileSystemDocumentLoader;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import io.quarkiverse.langchain4j.redis.RedisEmbeddingStore;
import io.quarkus.runtime.StartupEvent;

@ApplicationScoped
public class IngestorExample {

/**
* The embedding store (the database).
* The bean is provided by the quarkus-langchain4j-redis extension.
*/
@Inject
RedisEmbeddingStore store;

/**
* The embedding model (how the vector of a document is computed).
* The bean is provided by the LLM (like openai) extension.
*/
@Inject
EmbeddingModel embeddingModel;

public void ingest(@Observes StartupEvent event) {
System.out.printf("Ingesting documents...%n");
List<Document> documents = FileSystemDocumentLoader.loadDocuments(new File("src/main/resources/catalog").toPath());
var ingestor = EmbeddingStoreIngestor.builder()
.embeddingStore(store)
.embeddingModel(embeddingModel)
.documentSplitter(recursive(500, 0))
.build();
ingestor.ingest(documents);
System.out.printf("Ingested %d documents.%n", documents.size());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkiverse.langchain4j.sample.chatbot;

import java.util.List;

import jakarta.enterprise.context.ApplicationScoped;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.retriever.Retriever;
import io.quarkiverse.langchain4j.redis.RedisEmbeddingStore;

@ApplicationScoped
public class RetrieverExample implements Retriever<TextSegment> {

private final EmbeddingStoreRetriever retriever;

RetrieverExample(RedisEmbeddingStore store, EmbeddingModel model) {
retriever = EmbeddingStoreRetriever.from(store, model, 20);
}

@Override
public List<TextSegment> findRelevant(String s) {
return retriever.findRelevant(s);
}
}
Loading

0 comments on commit 02d4244

Please sign in to comment.