Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add queryMode option #101

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
OPENAI_ORGANIZATION: ${{ secrets.OPENAI_ORGANIZATION }}
QSTASH_TOKEN: ${{ secrets.QSTASH_TOKEN }}
HYBRID_EMBEDDING_UPSTASH_VECTOR_REST_URL: ${{ secrets.HYBRID_EMBEDDING_UPSTASH_VECTOR_REST_URL }}
HYBRID_EMBEDDING_UPSTASH_VECTOR_REST_TOKEN: ${{ secrets.HYBRID_EMBEDDING_UPSTASH_VECTOR_REST_TOKEN }}

jobs:
test:
Expand Down
Binary file modified bun.lockb
Binary file not shown.
6 changes: 3 additions & 3 deletions examples/nextjs/chat-to-website/ci.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async function resetResources() {
test(
"should invoke chat",
async () => {
await resetResources();
// await resetResources();
console.log("reset resources");

await invokeLoadPage();
Expand All @@ -88,8 +88,8 @@ test(
console.log(result);

const lowerCaseResult = result.toLowerCase();
expect(lowerCaseResult.includes("foo")).toBeTrue();
expect(lowerCaseResult.includes("bar")).toBeFalse();
expect(lowerCaseResult.includes("foo")).toBeFalse();
expect(lowerCaseResult.includes("bar")).toBeTrue();
},
{ timeout: 20_000 }
);
3 changes: 1 addition & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"@langchain/community": "^0.3.4",
"@langchain/core": "^0.2.9",
"@langchain/mistralai": "^0.0.28",
"@upstash/vector": "^1.1.3",
CahidArda marked this conversation as resolved.
Show resolved Hide resolved
"@upstash/vector": "^1.2.0",
"ai": "^3.1.1",
"cheerio": "^1.0.0-rc.12",
"d3-dsv": "^3.0.1",
Expand All @@ -89,7 +89,6 @@
"@langchain/openai": "^0.2.8",
"@upstash/ratelimit": "^1 || ^2",
"@upstash/redis": "^1.34.0",
"@upstash/vector": "^1.1.5",
"react": "^18 || ^19",
"react-dom": "^18 || ^19"
}
Expand Down
1 change: 1 addition & 0 deletions src/context-service/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ export class ContextService {
topK: optionsWithDefault.topK,
namespace: optionsWithDefault.namespace,
contextFilter: optionsWithDefault.contextFilter,
queryMode: optionsWithDefault.queryMode,
});

// Log the result, which will be captured by the outer traceable
Expand Down
5 changes: 4 additions & 1 deletion src/database.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { WebBaseLoaderParams } from "@langchain/community/document_loaders/web/cheerio";
import type { Index } from "@upstash/vector";
import type { Index, QueryMode } from "@upstash/vector";
import type { RecursiveCharacterTextSplitterParams } from "langchain/text_splitter";
import { nanoid } from "nanoid";
import { DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TOP_K } from "./constants";
Expand Down Expand Up @@ -74,6 +74,7 @@ export type VectorPayload = {
topK?: number;
namespace?: string;
contextFilter?: string;
queryMode?: QueryMode;
};

export type ResetOptions = {
Expand Down Expand Up @@ -108,6 +109,7 @@ export class Database {
topK = DEFAULT_TOP_K,
namespace,
contextFilter,
queryMode,
}: VectorPayload): Promise<{ data: string; id: string; metadata: TMetadata }[]> {
const index = this.index;
const result = await index.query<Record<string, string>>(
Expand All @@ -117,6 +119,7 @@ export class Database {
includeData: true,
includeMetadata: true,
...(typeof contextFilter === "string" && { filter: contextFilter }),
queryMode,
},
{ namespace }
);
Expand Down
44 changes: 43 additions & 1 deletion src/rag-chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { ChatOpenAI } from "@langchain/openai";
import { openai } from "@ai-sdk/openai";
import { Ratelimit } from "@upstash/ratelimit";
import { Redis } from "@upstash/redis";
import { Index } from "@upstash/vector";
import { Index, QueryMode } from "@upstash/vector";
import { LangChainAdapter, StreamingTextResponse } from "ai";
import {
afterAll,
Expand Down Expand Up @@ -49,6 +49,11 @@ describe("RAG Chat with advance configs and direct instances", () => {
url: process.env.UPSTASH_REDIS_REST_URL!,
});

const hybridVector = new Index({
url: process.env.HYBRID_EMBEDDING_UPSTASH_VECTOR_REST_URL!,
token: process.env.HYBRID_EMBEDDING_UPSTASH_VECTOR_REST_TOKEN!,
});

const ragChat = new RAGChat({
model: upstashOpenai("gpt-3.5-turbo"),
vector,
Expand All @@ -69,6 +74,7 @@ describe("RAG Chat with advance configs and direct instances", () => {
await vector.reset({ namespace });
await vector.deleteNamespace(namespace);
await redis.flushdb();
await hybridVector.reset({ all: true });
});

test("should get result without streaming", async () => {
Expand All @@ -89,6 +95,42 @@ describe("RAG Chat with advance configs and direct instances", () => {
"Paris, the capital of France, is renowned for its iconic landmark, the Eiffel Tower, which was completed in 1889 and stands at 330 meters tall."
);
});

test("should retrieve with query mode", async () => {
const ragChat = new RAGChat({
vector: hybridVector,
streaming: true,
model: upstash("meta-llama/Meta-Llama-3-8B-Instruct"),
});

await ragChat.context.add({
type: "text",
data: "foo is bar",
});
await ragChat.context.add({
type: "text",
data: "foo is zed",
});
await awaitUntilIndexed(hybridVector);

const result = await ragChat.chat<{ unit: string }>("what is foo or bar?", {
topK: 1,
similarityThreshold: 0,
queryMode: QueryMode.SPARSE,
onContextFetched(context) {
expect(context.length).toBe(1);
return context;
},
});

expect(result.context).toEqual([
{
data: "foo is bar",
id: expect.any(String) as string,
metadata: undefined,
},
]);
});
});

describe("RAG Chat with ratelimit", () => {
Expand Down
1 change: 1 addition & 0 deletions src/rag-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ export class RAGChat {
? DEFAULT_PROMPT_WITHOUT_RAG
: (options?.promptFn ?? this.config.prompt),
contextFilter: options?.contextFilter ?? undefined,
queryMode: options?.queryMode ?? undefined,
};
}
}
10 changes: 9 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type { ChatOpenAI } from "@langchain/openai";
import type { openai } from "@ai-sdk/openai";
import type { Ratelimit } from "@upstash/ratelimit";
import type { Redis } from "@upstash/redis";
import type { Index } from "@upstash/vector";
import type { Index, QueryMode } from "@upstash/vector";
import type { CustomPrompt } from "./rag-chat";
import type { ChatMistralAI } from "@langchain/mistralai";
import type { ChatAnthropic } from "@langchain/anthropic";
Expand Down Expand Up @@ -92,6 +92,14 @@ export type ChatOptions = {
* https://upstash.com/docs/vector/features/filtering#metadata-filtering
*/
contextFilter?: string;

/**
* Query mode to use when querying a hybrid index.
*
* This is useful if your index is a hybrid index and you want to query the
* sparse or dense part when you pass `data`.
*/
queryMode?: QueryMode;
} & CommonChatAndRAGOptions;

export type PrepareChatResult = { data: string; id: string; metadata: unknown }[];
Expand Down
Loading