From 8f63f8d6d08c0d12262f6a183f8892d71add0523 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Filip=20=C5=BBmijewski?= Date: Wed, 27 Nov 2024 14:37:04 +0100 Subject: [PATCH] Change compressDocuments() method, add tests regarding changes" --- .../src/document_compressors/ibm.ts | 3 + .../tests/ibm.int.test.ts | 71 ++++++++++++++----- 2 files changed, 57 insertions(+), 17 deletions(-) diff --git a/libs/langchain-community/src/document_compressors/ibm.ts b/libs/langchain-community/src/document_compressors/ibm.ts index 348f60685480..026219cc8fa8 100644 --- a/libs/langchain-community/src/document_compressors/ibm.ts +++ b/libs/langchain-community/src/document_compressors/ibm.ts @@ -115,6 +115,9 @@ export class WatsonxRerank ...this.scopeId(), inputs, query, + parameters: { + truncate_input_tokens: this.truncateInputTokens, + }, }) ); const resultDocuments = result.results.map(({ index, score }) => { diff --git a/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts b/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts index e65ea9e1eff3..6994bcec7c1a 100644 --- a/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts +++ b/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts @@ -40,6 +40,25 @@ describe("Integration tests on WatsonxRerank", () => { expect(typeof item.metadata.relevanceScore).toBe("number") ); }); + + test("Basic call with truncation", async () => { + const instance = new WatsonxRerank({ + model: "cross-encoder/ms-marco-minilm-l-12-v2", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + version: "2024-05-31", + projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", + truncateInputTokens: 512, + }); + const longerDocs: Document[] = docs.map((item) => ({ + pageContent: item.pageContent.repeat(100), + metadata: {}, + })); + const result = await instance.compressDocuments(longerDocs, query); + expect(result.length).toBe(docs.length); + result.forEach((item) => + expect(typeof item.metadata.relevanceScore).toBe("number") + ); + }); }); describe(".rerank() method", () => { @@ -57,24 +76,42 @@ describe("Integration tests on WatsonxRerank", () => { expect(item.input).toBeUndefined(); }); }); - }); - test("Basic call with options", async () => { - const instance = new WatsonxRerank({ - model: "cross-encoder/ms-marco-minilm-l-12-v2", - serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, - version: "2024-05-31", - projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", - }); - const result = await instance.rerank(docs, query, { - returnOptions: { - topN: 3, - inputs: true, - }, + test("Basic call with options", async () => { + const instance = new WatsonxRerank({ + model: "cross-encoder/ms-marco-minilm-l-12-v2", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + version: "2024-05-31", + projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", + }); + const result = await instance.rerank(docs, query, { + returnOptions: { + topN: 3, + inputs: true, + }, + }); + expect(result.length).toBe(3); + result.forEach((item) => { + expect(typeof item.relevanceScore).toBe("number"); + expect(item.input).toBeDefined(); + }); }); - expect(result.length).toBe(3); - result.forEach((item) => { - expect(typeof item.relevanceScore).toBe("number"); - expect(item.input).toBeDefined(); + test("Basic call with truncation", async () => { + const instance = new WatsonxRerank({ + model: "cross-encoder/ms-marco-minilm-l-12-v2", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + version: "2024-05-31", + projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", + }); + const longerDocs = docs.map((item) => ({ + pageContent: item.pageContent.repeat(100), + })); + const result = await instance.rerank(longerDocs, query, { + truncateInputTokens: 512, + }); + result.forEach((item) => { + expect(typeof item.relevanceScore).toBe("number"); + expect(item.input).toBeUndefined(); + }); }); }); });