diff --git a/libs/langchain-community/src/chat_models/ibm.ts b/libs/langchain-community/src/chat_models/ibm.ts index 399e6d4d2909..8c329c70a9af 100644 --- a/libs/langchain-community/src/chat_models/ibm.ts +++ b/libs/langchain-community/src/chat_models/ibm.ts @@ -17,11 +17,11 @@ import { BaseLanguageModelInput, FunctionDefinition, StructuredOutputMethodOptions, - type BaseLanguageModelCallOptions, } from "@langchain/core/language_models/base"; import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { BaseChatModel, + BaseChatModelCallOptions, BindToolsInput, LangSmithParams, type BaseChatModelParams, @@ -41,7 +41,6 @@ import { TextChatResultChoice, TextChatResultMessage, TextChatToolCall, - TextChatToolChoiceTool, TextChatUsage, } from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js"; import { WatsonXAI } from "@ibm-cloud/watsonx-ai"; @@ -80,14 +79,14 @@ export interface WatsonxDeltaStream { } export interface WatsonxCallParams - extends Partial> { + extends Partial> { maxRetries?: number; } export interface WatsonxCallOptionsChat - extends Omit, + extends Omit, WatsonxCallParams { promptIndex?: number; - tool_choice?: TextChatToolChoiceTool; + tool_choice?: TextChatParameterTools | string | "auto" | "any"; } type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools; @@ -309,6 +308,29 @@ function _convertDeltaToMessageChunk( return null; } +function _convertToolChoiceToWatsonxToolChoice( + toolChoice: TextChatParameterTools | string | "auto" | "any" +) { + if (typeof toolChoice === "string") { + if (toolChoice === "any" || toolChoice === "required") { + return { toolChoiceOption: "required" }; + } else if (toolChoice === "auto" || toolChoice === "none") { + return { toolChoiceOption: toolChoice }; + } else { + return { + toolChoice: { + type: "function", + function: { name: toolChoice }, + }, + }; + } + } else if ("type" in toolChoice) return { toolChoice }; + else + throw new Error( + `Unrecognized tool_choice type. Expected string or TextChatParameterTools. Recieved ${toolChoice}` + ); +} + export class ChatWatsonx< CallOptions extends WatsonxCallOptionsChat = WatsonxCallOptionsChat > @@ -459,7 +481,7 @@ export class ChatWatsonx< } invocationParams(options: this["ParsedCallOptions"]) { - return { + const params = { maxTokens: options.maxTokens ?? this.maxTokens, temperature: options?.temperature ?? this.temperature, timeLimit: options?.timeLimit ?? this.timeLimit, @@ -472,10 +494,12 @@ export class ChatWatsonx< tools: options.tools ? _convertToolToWatsonxTool(options.tools) : undefined, - toolChoice: options.tool_choice, responseFormat: options.responseFormat, - toolChoiceOption: options.toolChoiceOption, }; + const toolChoiceResult = options.tool_choice + ? _convertToolChoiceToWatsonxToolChoice(options.tool_choice) + : {}; + return { ...params, ...toolChoiceResult }; } override bindTools( @@ -562,7 +586,7 @@ export class ChatWatsonx< .map(([_, value]) => value); return { generations, llmOutput: { tokenUsage } }; } else { - const params: Omit = { + const params = { ...this.invocationParams(options), ...this.scopeId(), }; @@ -576,7 +600,6 @@ export class ChatWatsonx< messages: watsonxMessages, }); const { result } = await this.completionWithRetry(callback, options); - const generations: ChatGeneration[] = []; for (const part of result.choices) { const generation: ChatGeneration = { @@ -623,10 +646,13 @@ export class ChatWatsonx< }); const stream = await this.completionWithRetry(callback, options); let defaultRole; + let usage: TextChatUsage | undefined; + let currentCompletion = 0; for await (const chunk of stream) { if (options.signal?.aborted) { throw new Error("AbortError"); } + if (chunk?.data?.usage) usage = chunk.data.usage; const { data } = chunk; const choice = data.choices[0] as TextChatResultChoice & Record<"delta", TextChatResultMessage>; @@ -638,7 +664,7 @@ export class ChatWatsonx< if (!delta) { continue; } - + currentCompletion = choice.index ?? 0; const newTokenIndices = { prompt: options.promptIndex ?? 0, completion: choice.index ?? 0, @@ -682,6 +708,26 @@ export class ChatWatsonx< { chunk: generationChunk } ); } + + const generationChunk = new ChatGenerationChunk({ + message: new AIMessageChunk({ + content: "", + response_metadata: { + usage, + }, + usage_metadata: { + input_tokens: usage?.prompt_tokens ?? 0, + output_tokens: usage?.completion_tokens ?? 0, + total_tokens: usage?.total_tokens ?? 0, + }, + }), + text: "", + generationInfo: { + prompt: options.promptIndex ?? 0, + completion: currentCompletion ?? 0, + }, + }); + yield generationChunk; } /** @ignore */ @@ -760,7 +806,7 @@ export class ChatWatsonx< }, ], // Ideally that would be set to required but this is not supported yet - toolChoice: { + tool_choice: { type: "function", function: { name: functionName, @@ -796,7 +842,7 @@ export class ChatWatsonx< }, ], // Ideally that would be set to required but this is not supported yet - toolChoice: { + tool_choice: { type: "function", function: { name: functionName, diff --git a/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts index 2f1d118d92a4..be8d6615a402 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts @@ -150,7 +150,7 @@ describe("Tests for chat", () => { controller.abort(); return res; }).rejects.toThrow(); - }, 5000); + }); }); describe("Test ChatWatsonx invoke and generate with stream mode", () => { @@ -357,7 +357,7 @@ describe("Tests for chat", () => { controller.abort(); return res; }).rejects.toThrow(); - }, 5000); + }); }); describe("Test ChatWatsonx stream", () => { @@ -415,7 +415,7 @@ describe("Tests for chat", () => { } expect(hasEntered).toBe(true); }).rejects.toThrow(); - }, 5000); + }); test("Token count and response equality", async () => { let generation = ""; const service = new ChatWatsonx({ diff --git a/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts index 545ed3c06fa9..68b967d972b7 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts @@ -26,7 +26,7 @@ class ChatWatsonxStandardIntegrationTests extends ChatModelIntegrationTests< chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, constructorArgs: { - model: "mistralai/mistral-large", + model: "meta-llama/llama-3-1-70b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", diff --git a/libs/langchain-community/src/document_compressors/ibm.ts b/libs/langchain-community/src/document_compressors/ibm.ts index 348f60685480..026219cc8fa8 100644 --- a/libs/langchain-community/src/document_compressors/ibm.ts +++ b/libs/langchain-community/src/document_compressors/ibm.ts @@ -115,6 +115,9 @@ export class WatsonxRerank ...this.scopeId(), inputs, query, + parameters: { + truncate_input_tokens: this.truncateInputTokens, + }, }) ); const resultDocuments = result.results.map(({ index, score }) => { diff --git a/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts b/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts index e65ea9e1eff3..6994bcec7c1a 100644 --- a/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts +++ b/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts @@ -40,6 +40,25 @@ describe("Integration tests on WatsonxRerank", () => { expect(typeof item.metadata.relevanceScore).toBe("number") ); }); + + test("Basic call with truncation", async () => { + const instance = new WatsonxRerank({ + model: "cross-encoder/ms-marco-minilm-l-12-v2", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + version: "2024-05-31", + projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", + truncateInputTokens: 512, + }); + const longerDocs: Document[] = docs.map((item) => ({ + pageContent: item.pageContent.repeat(100), + metadata: {}, + })); + const result = await instance.compressDocuments(longerDocs, query); + expect(result.length).toBe(docs.length); + result.forEach((item) => + expect(typeof item.metadata.relevanceScore).toBe("number") + ); + }); }); describe(".rerank() method", () => { @@ -57,24 +76,42 @@ describe("Integration tests on WatsonxRerank", () => { expect(item.input).toBeUndefined(); }); }); - }); - test("Basic call with options", async () => { - const instance = new WatsonxRerank({ - model: "cross-encoder/ms-marco-minilm-l-12-v2", - serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, - version: "2024-05-31", - projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", - }); - const result = await instance.rerank(docs, query, { - returnOptions: { - topN: 3, - inputs: true, - }, + test("Basic call with options", async () => { + const instance = new WatsonxRerank({ + model: "cross-encoder/ms-marco-minilm-l-12-v2", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + version: "2024-05-31", + projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", + }); + const result = await instance.rerank(docs, query, { + returnOptions: { + topN: 3, + inputs: true, + }, + }); + expect(result.length).toBe(3); + result.forEach((item) => { + expect(typeof item.relevanceScore).toBe("number"); + expect(item.input).toBeDefined(); + }); }); - expect(result.length).toBe(3); - result.forEach((item) => { - expect(typeof item.relevanceScore).toBe("number"); - expect(item.input).toBeDefined(); + test("Basic call with truncation", async () => { + const instance = new WatsonxRerank({ + model: "cross-encoder/ms-marco-minilm-l-12-v2", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + version: "2024-05-31", + projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", + }); + const longerDocs = docs.map((item) => ({ + pageContent: item.pageContent.repeat(100), + })); + const result = await instance.rerank(longerDocs, query, { + truncateInputTokens: 512, + }); + result.forEach((item) => { + expect(typeof item.relevanceScore).toBe("number"); + expect(item.input).toBeUndefined(); + }); }); }); }); diff --git a/libs/langchain-community/src/llms/ibm.ts b/libs/langchain-community/src/llms/ibm.ts index a0e8a292f0bf..5647e2f3ae9f 100644 --- a/libs/langchain-community/src/llms/ibm.ts +++ b/libs/langchain-community/src/llms/ibm.ts @@ -259,7 +259,9 @@ export class WatsonxLLM< input: string, options: this["ParsedCallOptions"], stream: true - ): Promise>; + ): Promise< + AsyncIterable> + >; private async generateSingleMessage( input: string, @@ -294,14 +296,16 @@ export class WatsonxLLM< input, }, }, + returnObject: true, }) : await this.service.generateTextStream({ input, parameters, ...this.scopeId(), ...requestOptions, + returnObject: true, }); - return textStream as unknown as AsyncIterable; + return textStream; } else { const textGenerationPromise = idOrName ? this.service.deploymentGenerateText({ @@ -367,7 +371,7 @@ export class WatsonxLLM< async _generate( prompts: string[], options: this["ParsedCallOptions"], - _runManager?: CallbackManagerForLLMRun + runManager?: CallbackManagerForLLMRun ): Promise { const tokenUsage: TokenUsage = { generated_token_count: 0, @@ -379,70 +383,38 @@ export class WatsonxLLM< if (options.signal?.aborted) { throw new Error("AbortError"); } - const callback = () => - this.generateSingleMessage(prompt, options, true); - type ReturnMessage = ReturnType; - const stream = await this.completionWithRetry( - callback, - options - ); + const stream = this._streamResponseChunks(prompt, options); + const geneartionsArray: GenerationInfo[] = []; - const responseChunk: ResponseChunk = { - id: 0, - event: "", - data: { - results: [], - }, - }; - const messages: ResponseChunk[] = []; - type ResponseChunkKeys = keyof ResponseChunk; for await (const chunk of stream) { - if (chunk.length > 0) { - const index = chunk.indexOf(": "); - const [key, value] = [ - chunk.substring(0, index) as ResponseChunkKeys, - chunk.substring(index + 2), - ]; - if (key === "id") { - responseChunk[key] = Number(value); - } else if (key === "event") { - responseChunk[key] = String(value); - } else { - responseChunk[key] = JSON.parse(value); - } - } else if (chunk.length === 0) { - messages.push(JSON.parse(JSON.stringify(responseChunk))); - Object.assign(responseChunk, { id: 0, event: "", data: {} }); - } - } - - const geneartionsArray: GenerationInfo[] = []; - for (const message of messages) { - message.data.results.forEach((item, index) => { - const generationInfo: GenerationInfo = { - text: "", - stop_reason: "", - generated_token_count: 0, - input_token_count: 0, - }; - void _runManager?.handleLLMNewToken(item.generated_text ?? "", { + const completion = chunk?.generationInfo?.completion ?? 0; + const generationInfo: GenerationInfo = { + text: "", + stop_reason: "", + generated_token_count: 0, + input_token_count: 0, + }; + geneartionsArray[completion] ??= generationInfo; + geneartionsArray[completion].generated_token_count = + chunk?.generationInfo?.usage_metadata.generated_token_count ?? 0; + geneartionsArray[completion].input_token_count += + chunk?.generationInfo?.usage_metadata.input_token_count ?? 0; + geneartionsArray[completion].stop_reason = + chunk?.generationInfo?.stop_reason; + geneartionsArray[completion].text += chunk.text; + if (chunk.text) + void runManager?.handleLLMNewToken(chunk.text, { prompt: promptIdx, - completion: 1, + completion: 0, }); - geneartionsArray[index] ??= generationInfo; - geneartionsArray[index].generated_token_count = - item.generated_token_count; - geneartionsArray[index].input_token_count += - item.input_token_count; - geneartionsArray[index].stop_reason = item.stop_reason; - geneartionsArray[index].text += item.generated_text; - }); } + return geneartionsArray.map((item) => { const { text, ...rest } = item; - tokenUsage.generated_token_count += rest.generated_token_count; + tokenUsage.generated_token_count = rest.generated_token_count; tokenUsage.input_token_count += rest.input_token_count; + return { text, generationInfo: rest, @@ -527,35 +499,23 @@ export class WatsonxLLM< throw new Error("AbortError"); } - type Keys = keyof typeof responseChunk; - if (chunk.length > 0) { - const index = chunk.indexOf(": "); - const [key, value] = [ - chunk.substring(0, index) as Keys, - chunk.substring(index + 2), - ]; - if (key === "id") { - responseChunk[key] = Number(value); - } else if (key === "event") { - responseChunk[key] = String(value); - } else { - responseChunk[key] = JSON.parse(value); - } - } else if ( - chunk.length === 0 && - responseChunk.data?.results?.length > 0 - ) { - for (const item of responseChunk.data.results) { - yield new GenerationChunk({ - text: item.generated_text, - generationInfo: { + for (const [index, item] of chunk.data.results.entries()) { + yield new GenerationChunk({ + text: item.generated_text, + generationInfo: { + stop_reason: item.stop_reason, + completion: index, + usage_metadata: { + generated_token_count: item.generated_token_count, + input_token_count: item.input_token_count, stop_reason: item.stop_reason, }, - }); - await runManager?.handleLLMNewToken(item.generated_text ?? ""); - } - Object.assign(responseChunk, { id: 0, event: "", data: {} }); + }, + }); + if (item.generated_text) + void runManager?.handleLLMNewToken(item.generated_text); } + Object.assign(responseChunk, { id: 0, event: "", data: {} }); } } diff --git a/libs/langchain-community/src/llms/tests/ibm.int.test.ts b/libs/langchain-community/src/llms/tests/ibm.int.test.ts index dfeebedd39e2..369b657fb4ca 100644 --- a/libs/langchain-community/src/llms/tests/ibm.int.test.ts +++ b/libs/langchain-community/src/llms/tests/ibm.int.test.ts @@ -172,7 +172,6 @@ describe("Text generation", () => { let usedTokens = 0; const model = new WatsonxLLM({ model: "ibm/granite-13b-chat-v2", - maxConcurrency: 1, version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -190,7 +189,7 @@ describe("Text generation", () => { }), }); - const res = await model.invoke(" Print hello world?"); + const res = await model.invoke("Print hello world?"); expect(countedTokens).toBe(usedTokens); expect(res).toBe(streamedText); });