From 90fb95aaed891fe099a3740f008be882e975278f Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Thu, 9 Jan 2025 17:15:13 +0100 Subject: [PATCH] chore (provider-utils): switch to unified test server (#4330) --- .changeset/late-cups-dance.md | 7 + .../anthropic-messages-language-model.test.ts | 1308 +++++++++-------- .../src/fireworks-image-model.test.ts | 156 +- .../src/test/binary-test-server.test.ts | 186 --- .../src/test/binary-test-server.ts | 131 -- packages/provider-utils/src/test/index.ts | 1 - .../src/test/json-test-server.ts | 6 + .../src/test/streaming-test-server.ts | 6 + .../src/test/unified-test-server.ts | 156 +- .../src/replicate-image-model.test.ts | 7 +- 10 files changed, 868 insertions(+), 1096 deletions(-) create mode 100644 .changeset/late-cups-dance.md delete mode 100644 packages/provider-utils/src/test/binary-test-server.test.ts delete mode 100644 packages/provider-utils/src/test/binary-test-server.ts diff --git a/.changeset/late-cups-dance.md b/.changeset/late-cups-dance.md new file mode 100644 index 000000000000..9c48d4f5973f --- /dev/null +++ b/.changeset/late-cups-dance.md @@ -0,0 +1,7 @@ +--- +'@ai-sdk/provider-utils': patch +'@ai-sdk/fireworks': patch +'@ai-sdk/replicate': patch +--- + +chore (provider-utils): switch to unified test server diff --git a/packages/anthropic/src/anthropic-messages-language-model.test.ts b/packages/anthropic/src/anthropic-messages-language-model.test.ts index 445895923de5..8929fe798be7 100644 --- a/packages/anthropic/src/anthropic-messages-language-model.test.ts +++ b/packages/anthropic/src/anthropic-messages-language-model.test.ts @@ -1,8 +1,7 @@ import { LanguageModelV1Prompt } from '@ai-sdk/provider'; import { - JsonTestServer, - StreamingTestServer, convertReadableStreamToArray, + createTestServer, } from '@ai-sdk/provider-utils/test'; import { AnthropicAssistantMessage } from './anthropic-api-types'; import { createAnthropic } from './anthropic-provider'; @@ -14,87 +13,146 @@ const TEST_PROMPT: LanguageModelV1Prompt = [ const provider = createAnthropic({ apiKey: 'test-api-key' }); const model = provider('claude-3-haiku-20240307'); -describe('doGenerate', () => { - const server = new JsonTestServer('https://api.anthropic.com/v1/messages'); - - server.setupTestEnvironment(); - - function prepareJsonResponse({ - content = [{ type: 'text', text: '', cache_control: undefined }], - usage = { - input_tokens: 4, - output_tokens: 30, - }, - stopReason = 'end_turn', - id = 'msg_017TfcQ4AgGxKyBduUpqYPZn', - model = 'claude-3-haiku-20240307', - }: { - content?: AnthropicAssistantMessage['content']; - usage?: { - input_tokens: number; - output_tokens: number; - cache_creation_input_tokens?: number; - cache_read_input_tokens?: number; - }; - stopReason?: string; - id?: string; - model?: string; - }) { - server.responseBodyJson = { - id, - type: 'message', - role: 'assistant', - content, - model, - stop_reason: stopReason, - stop_sequence: null, - usage, - }; - } - - it('should extract text response', async () => { - prepareJsonResponse({ - content: [ - { type: 'text', text: 'Hello, World!', cache_control: undefined }, - ], - }); +describe('AnthropicMessagesLanguageModel', () => { + const server = createTestServer({ + 'https://api.anthropic.com/v1/messages': {}, + }); - const { text } = await provider('claude-3-haiku-20240307').doGenerate({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - }); + describe('doGenerate', () => { + function prepareJsonResponse({ + content = [{ type: 'text', text: '', cache_control: undefined }], + usage = { + input_tokens: 4, + output_tokens: 30, + }, + stopReason = 'end_turn', + id = 'msg_017TfcQ4AgGxKyBduUpqYPZn', + model = 'claude-3-haiku-20240307', + headers = {}, + }: { + content?: AnthropicAssistantMessage['content']; + usage?: { + input_tokens: number; + output_tokens: number; + cache_creation_input_tokens?: number; + cache_read_input_tokens?: number; + }; + stopReason?: string; + id?: string; + model?: string; + headers?: Record; + }) { + server.urls['https://api.anthropic.com/v1/messages'].response = { + type: 'json-value', + headers, + body: { + id, + type: 'message', + role: 'assistant', + content, + model, + stop_reason: stopReason, + stop_sequence: null, + usage, + }, + }; + } - expect(text).toStrictEqual('Hello, World!'); - }); + it('should extract text response', async () => { + prepareJsonResponse({ + content: [ + { type: 'text', text: 'Hello, World!', cache_control: undefined }, + ], + }); - it('should extract tool calls', async () => { - prepareJsonResponse({ - content: [ - { - type: 'text', - text: 'Some text\n\n', - cache_control: undefined, + const { text } = await provider('claude-3-haiku-20240307').doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(text).toStrictEqual('Hello, World!'); + }); + + it('should extract tool calls', async () => { + prepareJsonResponse({ + content: [ + { + type: 'text', + text: 'Some text\n\n', + cache_control: undefined, + }, + { + type: 'tool_use', + id: 'toolu_1', + name: 'test-tool', + input: { value: 'example value' }, + cache_control: undefined, + }, + ], + stopReason: 'tool_use', + }); + + const { toolCalls, finishReason, text } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool', + parameters: { + type: 'object', + properties: { value: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + ], }, + prompt: TEST_PROMPT, + }); + + expect(toolCalls).toStrictEqual([ { - type: 'tool_use', - id: 'toolu_1', - name: 'test-tool', - input: { value: 'example value' }, - cache_control: undefined, + toolCallId: 'toolu_1', + toolCallType: 'function', + toolName: 'test-tool', + args: '{"value":"example value"}', }, - ], - stopReason: 'tool_use', + ]); + expect(text).toStrictEqual('Some text\n\n'); + expect(finishReason).toStrictEqual('tool-calls'); }); - const { toolCalls, finishReason, text } = await model.doGenerate({ - inputFormat: 'prompt', - mode: { - type: 'regular', - tools: [ + it('should support object-tool mode', async () => { + prepareJsonResponse({ + content: [ + { + type: 'text', + text: 'Some text\n\n', + cache_control: undefined, + }, { + type: 'tool_use', + id: 'toolu_1', + name: 'json', + input: { value: 'example value' }, + cache_control: undefined, + }, + ], + stopReason: 'tool_use', + }); + + const { toolCalls, finishReason } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { + type: 'object-tool', + tool: { type: 'function', - name: 'test-tool', + name: 'json', + description: 'Respond with a JSON object.', parameters: { type: 'object', properties: { value: { type: 'string' } }, @@ -103,196 +161,171 @@ describe('doGenerate', () => { $schema: 'http://json-schema.org/draft-07/schema#', }, }, - ], - }, - prompt: TEST_PROMPT, - }); - - expect(toolCalls).toStrictEqual([ - { - toolCallId: 'toolu_1', - toolCallType: 'function', - toolName: 'test-tool', - args: '{"value":"example value"}', - }, - ]); - expect(text).toStrictEqual('Some text\n\n'); - expect(finishReason).toStrictEqual('tool-calls'); - }); - - it('should support object-tool mode', async () => { - prepareJsonResponse({ - content: [ - { - type: 'text', - text: 'Some text\n\n', - cache_control: undefined, }, + prompt: TEST_PROMPT, + }); + + expect(toolCalls).toStrictEqual([ { - type: 'tool_use', - id: 'toolu_1', - name: 'json', - input: { value: 'example value' }, - cache_control: undefined, + toolCallId: 'toolu_1', + toolCallType: 'function', + toolName: 'json', + args: '{"value":"example value"}', }, - ], - stopReason: 'tool_use', - }); + ]); + expect(finishReason).toStrictEqual('tool-calls'); - const { toolCalls, finishReason } = await model.doGenerate({ - inputFormat: 'prompt', - mode: { - type: 'object-tool', - tool: { - type: 'function', - name: 'json', - description: 'Respond with a JSON object.', - parameters: { - type: 'object', - properties: { value: { type: 'string' } }, - required: ['value'], - additionalProperties: false, - $schema: 'http://json-schema.org/draft-07/schema#', + // check request to Anthropic + expect(await server.calls[0].requestBody).toStrictEqual({ + max_tokens: 4096, + messages: [ + { + content: [{ text: 'Hello', type: 'text' }], + role: 'user', }, - }, - }, - prompt: TEST_PROMPT, - }); - - expect(toolCalls).toStrictEqual([ - { - toolCallId: 'toolu_1', - toolCallType: 'function', - toolName: 'json', - args: '{"value":"example value"}', - }, - ]); - expect(finishReason).toStrictEqual('tool-calls'); - - // check request to Anthropic - const requestBodyJson = await server.getRequestBodyJson(); - expect(requestBodyJson).toStrictEqual({ - max_tokens: 4096, - messages: [ - { - content: [{ text: 'Hello', type: 'text' }], - role: 'user', - }, - ], - model: 'claude-3-haiku-20240307', - tool_choice: { name: 'json', type: 'tool' }, - tools: [ - { - description: 'Respond with a JSON object.', - input_schema: { - $schema: 'http://json-schema.org/draft-07/schema#', - additionalProperties: false, - properties: { value: { type: 'string' } }, - required: ['value'], - type: 'object', + ], + model: 'claude-3-haiku-20240307', + tool_choice: { name: 'json', type: 'tool' }, + tools: [ + { + description: 'Respond with a JSON object.', + input_schema: { + $schema: 'http://json-schema.org/draft-07/schema#', + additionalProperties: false, + properties: { value: { type: 'string' } }, + required: ['value'], + type: 'object', + }, + name: 'json', }, - name: 'json', - }, - ], - }); - }); - - it('should extract usage', async () => { - prepareJsonResponse({ - usage: { input_tokens: 20, output_tokens: 5 }, + ], + }); }); - const { usage } = await model.doGenerate({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - }); + it('should extract usage', async () => { + prepareJsonResponse({ + usage: { input_tokens: 20, output_tokens: 5 }, + }); - expect(usage).toStrictEqual({ - promptTokens: 20, - completionTokens: 5, - }); - }); + const { usage } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); - it('should send additional response information', async () => { - prepareJsonResponse({ - id: 'test-id', - model: 'test-model', + expect(usage).toStrictEqual({ + promptTokens: 20, + completionTokens: 5, + }); }); - const { response } = await model.doGenerate({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - }); + it('should send additional response information', async () => { + prepareJsonResponse({ + id: 'test-id', + model: 'test-model', + }); - expect(response).toStrictEqual({ - id: 'test-id', - modelId: 'test-model', - }); - }); + const { response } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); - it('should expose the raw response headers', async () => { - prepareJsonResponse({}); - - server.responseHeaders = { - 'test-header': 'test-value', - }; - - const { rawResponse } = await model.doGenerate({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, + expect(response).toStrictEqual({ + id: 'test-id', + modelId: 'test-model', + }); }); - expect(rawResponse?.headers).toStrictEqual({ - // default headers: - 'content-length': '237', - 'content-type': 'application/json', - - // custom header - 'test-header': 'test-value', - }); - }); - - it('should send the model id and settings', async () => { - prepareJsonResponse({}); - - await model.doGenerate({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - temperature: 0.5, - maxTokens: 100, - topP: 0.9, - topK: 0.1, - stopSequences: ['abc', 'def'], - frequencyPenalty: 0.15, + it('should expose the raw response headers', async () => { + prepareJsonResponse({ + headers: { + 'test-header': 'test-value', + }, + }); + + const { rawResponse } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(rawResponse?.headers).toStrictEqual({ + // default headers: + 'content-length': '237', + 'content-type': 'application/json', + + // custom header + 'test-header': 'test-value', + }); + }); + + it('should send the model id and settings', async () => { + prepareJsonResponse({}); + + await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + temperature: 0.5, + maxTokens: 100, + topP: 0.9, + topK: 0.1, + stopSequences: ['abc', 'def'], + frequencyPenalty: 0.15, + }); + + expect(await server.calls[0].requestBody).toStrictEqual({ + model: 'claude-3-haiku-20240307', + max_tokens: 100, + stop_sequences: ['abc', 'def'], + temperature: 0.5, + top_k: 0.1, + top_p: 0.9, + messages: [ + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ], + }); }); - expect(await server.getRequestBodyJson()).toStrictEqual({ - model: 'claude-3-haiku-20240307', - max_tokens: 100, - stop_sequences: ['abc', 'def'], - temperature: 0.5, - top_k: 0.1, - top_p: 0.9, - messages: [{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }], - }); - }); + it('should pass tools and toolChoice', async () => { + prepareJsonResponse({}); - it('should pass tools and toolChoice', async () => { - prepareJsonResponse({}); + await model.doGenerate({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool', + parameters: { + type: 'object', + properties: { value: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + ], + toolChoice: { + type: 'tool', + toolName: 'test-tool', + }, + }, + prompt: TEST_PROMPT, + }); - await model.doGenerate({ - inputFormat: 'prompt', - mode: { - type: 'regular', + expect(await server.calls[0].requestBody).toStrictEqual({ + model: 'claude-3-haiku-20240307', + messages: [ + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ], + max_tokens: 4096, tools: [ { - type: 'function', name: 'test-tool', - parameters: { + input_schema: { type: 'object', properties: { value: { type: 'string' } }, required: ['value'], @@ -301,451 +334,438 @@ describe('doGenerate', () => { }, }, ], - toolChoice: { + tool_choice: { type: 'tool', - toolName: 'test-tool', - }, - }, - prompt: TEST_PROMPT, - }); - - expect(await server.getRequestBodyJson()).toStrictEqual({ - model: 'claude-3-haiku-20240307', - messages: [{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }], - max_tokens: 4096, - tools: [ - { name: 'test-tool', - input_schema: { - type: 'object', - properties: { value: { type: 'string' } }, - required: ['value'], - additionalProperties: false, - $schema: 'http://json-schema.org/draft-07/schema#', - }, }, - ], - tool_choice: { - type: 'tool', - name: 'test-tool', - }, - }); - }); - - it('should pass headers', async () => { - prepareJsonResponse({ content: [] }); - - const provider = createAnthropic({ - apiKey: 'test-api-key', - headers: { - 'Custom-Provider-Header': 'provider-header-value', - }, - }); - - await provider('claude-3-haiku-20240307').doGenerate({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - headers: { - 'Custom-Request-Header': 'request-header-value', - }, - }); - - const requestHeaders = await server.getRequestHeaders(); - - expect(requestHeaders).toStrictEqual({ - 'anthropic-version': '2023-06-01', - 'content-type': 'application/json', - 'custom-provider-header': 'provider-header-value', - 'custom-request-header': 'request-header-value', - 'x-api-key': 'test-api-key', - }); - }); - - it('should support cache control', async () => { - prepareJsonResponse({ - usage: { - input_tokens: 20, - output_tokens: 50, - cache_creation_input_tokens: 10, - cache_read_input_tokens: 5, - }, + }); }); - const model = provider('claude-3-haiku-20240307', { - cacheControl: true, - }); + it('should pass headers', async () => { + prepareJsonResponse({ content: [] }); - const result = await model.doGenerate({ - mode: { type: 'regular' }, - inputFormat: 'messages', - prompt: [ - { - role: 'user', - content: [{ type: 'text', text: 'Hello' }], - providerMetadata: { - anthropic: { - cacheControl: { type: 'ephemeral' }, - }, - }, + const provider = createAnthropic({ + apiKey: 'test-api-key', + headers: { + 'Custom-Provider-Header': 'provider-header-value', }, - ], - }); - - expect(await server.getRequestBodyJson()).toStrictEqual({ - model: 'claude-3-haiku-20240307', - messages: [ - { - role: 'user', - content: [ - { - type: 'text', - text: 'Hello', - cache_control: { type: 'ephemeral' }, - }, - ], + }); + + await provider('claude-3-haiku-20240307').doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + headers: { + 'Custom-Request-Header': 'request-header-value', }, - ], - max_tokens: 4096, - }); - - expect(result.providerMetadata).toStrictEqual({ - anthropic: { - cacheCreationInputTokens: 10, - cacheReadInputTokens: 5, - }, - }); - }); - - it('should send request body', async () => { - prepareJsonResponse({ content: [] }); - - const { request } = await model.doGenerate({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - }); - - expect(request).toStrictEqual({ - body: '{"model":"claude-3-haiku-20240307","max_tokens":4096,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}', - }); - }); -}); - -describe('doStream', () => { - const server = new StreamingTestServer( - 'https://api.anthropic.com/v1/messages', - ); - - server.setupTestEnvironment(); - - function prepareStreamResponse({ content }: { content: string[] }) { - server.responseChunks = [ - `data: {"type":"message_start","message":{"id":"msg_01KfpJoAEabmH2iHRRFjQMAG","type":"message","role":"assistant","content":[],"model":"claude-3-haiku-20240307","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":17,"output_tokens":1}} }\n\n`, - `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }\n\n`, - `data: {"type": "ping"}\n\n`, - ...content.map(text => { - return `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"${text}"} }\n\n`; - }), - `data: {"type":"content_block_stop","index":0 }\n\n`, - `data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":227} }\n\n`, - `data: {"type":"message_stop" }\n\n`, - ]; - } - - it('should stream text deltas', async () => { - prepareStreamResponse({ content: ['Hello', ', ', 'World!'] }); - - const { stream } = await model.doStream({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - }); + }); + + expect(await server.calls[0].requestHeaders).toStrictEqual({ + 'anthropic-version': '2023-06-01', + 'content-type': 'application/json', + 'custom-provider-header': 'provider-header-value', + 'custom-request-header': 'request-header-value', + 'x-api-key': 'test-api-key', + }); + }); + + it('should support cache control', async () => { + prepareJsonResponse({ + usage: { + input_tokens: 20, + output_tokens: 50, + cache_creation_input_tokens: 10, + cache_read_input_tokens: 5, + }, + }); - // note: space moved to last chunk bc of trimming - expect(await convertReadableStreamToArray(stream)).toStrictEqual([ - { - type: 'response-metadata', - id: 'msg_01KfpJoAEabmH2iHRRFjQMAG', - modelId: 'claude-3-haiku-20240307', - }, - { type: 'text-delta', textDelta: 'Hello' }, - { type: 'text-delta', textDelta: ', ' }, - { type: 'text-delta', textDelta: 'World!' }, - { - type: 'finish', - finishReason: 'stop', - usage: { promptTokens: 17, completionTokens: 227 }, - providerMetadata: undefined, - }, - ]); - }); + const model = provider('claude-3-haiku-20240307', { + cacheControl: true, + }); - it('should stream tool deltas', async () => { - server.responseChunks = [ - `data: {"type":"message_start","message":{"id":"msg_01GouTqNCGXzrj5LQ5jEkw67","type":"message","role":"assistant","model":"claude-3-haiku-20240307","stop_sequence":null,"usage":{"input_tokens":441,"output_tokens":2},"content":[],"stop_reason":null} }\n\n`, - `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }\n\n`, - `data: {"type": "ping"}\n\n`, - `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Okay"} }\n\n`, - `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"!"} }\n\n`, - `data: {"type":"content_block_stop","index":0 }\n\n`, - `data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01DBsB4vvYLnBDzZ5rBSxSLs","name":"test-tool","input":{}} }\n\n`, - `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":""} }\n\n`, - `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\\"value"} }\n\n`, - `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"\\":"} }\n\n`, - `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"\\"Spark"} }\n\n`, - `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"le"} }\n\n`, - `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" Day\\"}"} }\n\n`, - `data: {"type":"content_block_stop","index":1 }\n\n`, - `data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":65} }\n\n`, - `data: {"type":"message_stop" }\n\n`, - ]; - - const { stream } = await model.doStream({ - inputFormat: 'prompt', - mode: { - type: 'regular', - tools: [ + const result = await model.doGenerate({ + mode: { type: 'regular' }, + inputFormat: 'messages', + prompt: [ { - type: 'function', - name: 'test-tool', - parameters: { - type: 'object', - properties: { value: { type: 'string' } }, - required: ['value'], - additionalProperties: false, - $schema: 'http://json-schema.org/draft-07/schema#', + role: 'user', + content: [{ type: 'text', text: 'Hello' }], + providerMetadata: { + anthropic: { + cacheControl: { type: 'ephemeral' }, + }, }, }, ], - }, - prompt: TEST_PROMPT, - }); + }); - expect(await convertReadableStreamToArray(stream)).toStrictEqual([ - { - type: 'response-metadata', - id: 'msg_01GouTqNCGXzrj5LQ5jEkw67', - modelId: 'claude-3-haiku-20240307', - }, - { - type: 'text-delta', - textDelta: 'Okay', - }, - { - type: 'text-delta', - textDelta: '!', - }, - { - type: 'tool-call-delta', - toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', - toolCallType: 'function', - toolName: 'test-tool', - argsTextDelta: '', - }, - { - type: 'tool-call-delta', - toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', - toolCallType: 'function', - toolName: 'test-tool', - argsTextDelta: '{"value', - }, - { - type: 'tool-call-delta', - toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', - toolCallType: 'function', - toolName: 'test-tool', - argsTextDelta: '":', - }, - { - type: 'tool-call-delta', - toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', - toolCallType: 'function', - toolName: 'test-tool', - argsTextDelta: '"Spark', - }, - { - type: 'tool-call-delta', - toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', - toolCallType: 'function', - toolName: 'test-tool', - argsTextDelta: 'le', - }, - { - type: 'tool-call-delta', - toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', - toolCallType: 'function', - toolName: 'test-tool', - argsTextDelta: ' Day"}', - }, - { - type: 'tool-call', - toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', - toolCallType: 'function', - toolName: 'test-tool', - args: '{"value":"Sparkle Day"}', - }, - { - type: 'finish', - finishReason: 'tool-calls', - usage: { promptTokens: 441, completionTokens: 65 }, - providerMetadata: undefined, - }, - ]); - }); + expect(await server.calls[0].requestBody).toStrictEqual({ + model: 'claude-3-haiku-20240307', + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'Hello', + cache_control: { type: 'ephemeral' }, + }, + ], + }, + ], + max_tokens: 4096, + }); - it('should forward error chunks', async () => { - server.responseChunks = [ - `data: {"type":"message_start","message":{"id":"msg_01KfpJoAEabmH2iHRRFjQMAG","type":"message","role":"assistant","content":[],"model":"claude-3-haiku-20240307","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":17,"output_tokens":1}} }\n\n`, - `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }\n\n`, - `data: {"type": "ping"}\n\n`, - `data: {"type":"error","error":{"type":"error","message":"test error"}}\n\n`, - ]; - - const { stream } = await model.doStream({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, + expect(result.providerMetadata).toStrictEqual({ + anthropic: { + cacheCreationInputTokens: 10, + cacheReadInputTokens: 5, + }, + }); }); - expect(await convertReadableStreamToArray(stream)).toStrictEqual([ - { - type: 'response-metadata', - id: 'msg_01KfpJoAEabmH2iHRRFjQMAG', - modelId: 'claude-3-haiku-20240307', - }, - { type: 'error', error: { type: 'error', message: 'test error' } }, - ]); - }); - - it('should expose the raw response headers', async () => { - prepareStreamResponse({ content: [] }); + it('should send request body', async () => { + prepareJsonResponse({ content: [] }); - server.responseHeaders = { - 'test-header': 'test-value', - }; + const { request } = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); - const { rawResponse } = await model.doStream({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, + expect(request).toStrictEqual({ + body: '{"model":"claude-3-haiku-20240307","max_tokens":4096,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}', + }); }); + }); - expect(rawResponse?.headers).toStrictEqual({ - // default headers: - 'content-type': 'text/event-stream', - 'cache-control': 'no-cache', - connection: 'keep-alive', + describe('doStream', () => { + function prepareStreamResponse({ + content, + headers, + }: { + content: string[]; + headers?: Record; + }) { + server.urls['https://api.anthropic.com/v1/messages'].response = { + type: 'stream-chunks', + headers, + chunks: [ + `data: {"type":"message_start","message":{"id":"msg_01KfpJoAEabmH2iHRRFjQMAG","type":"message","role":"assistant","content":[],"model":"claude-3-haiku-20240307","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":17,"output_tokens":1}} }\n\n`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }\n\n`, + `data: {"type": "ping"}\n\n`, + ...content.map(text => { + return `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"${text}"} }\n\n`; + }), + `data: {"type":"content_block_stop","index":0 }\n\n`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":227} }\n\n`, + `data: {"type":"message_stop" }\n\n`, + ], + }; + } - // custom header - 'test-header': 'test-value', - }); - }); + it('should stream text deltas', async () => { + prepareStreamResponse({ content: ['Hello', ', ', 'World!'] }); - it('should pass the messages and the model', async () => { - prepareStreamResponse({ content: [] }); + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); - await model.doStream({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - }); + // note: space moved to last chunk bc of trimming + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'response-metadata', + id: 'msg_01KfpJoAEabmH2iHRRFjQMAG', + modelId: 'claude-3-haiku-20240307', + }, + { type: 'text-delta', textDelta: 'Hello' }, + { type: 'text-delta', textDelta: ', ' }, + { type: 'text-delta', textDelta: 'World!' }, + { + type: 'finish', + finishReason: 'stop', + usage: { promptTokens: 17, completionTokens: 227 }, + providerMetadata: undefined, + }, + ]); + }); + + it('should stream tool deltas', async () => { + server.urls['https://api.anthropic.com/v1/messages'].response = { + type: 'stream-chunks', + chunks: [ + `data: {"type":"message_start","message":{"id":"msg_01GouTqNCGXzrj5LQ5jEkw67","type":"message","role":"assistant","model":"claude-3-haiku-20240307","stop_sequence":null,"usage":{"input_tokens":441,"output_tokens":2},"content":[],"stop_reason":null} }\n\n`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }\n\n`, + `data: {"type": "ping"}\n\n`, + `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Okay"} }\n\n`, + `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"!"} }\n\n`, + `data: {"type":"content_block_stop","index":0 }\n\n`, + `data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01DBsB4vvYLnBDzZ5rBSxSLs","name":"test-tool","input":{}} }\n\n`, + `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":""} }\n\n`, + `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\\"value"} }\n\n`, + `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"\\":"} }\n\n`, + `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"\\"Spark"} }\n\n`, + `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"le"} }\n\n`, + `data: {"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":" Day\\"}"} }\n\n`, + `data: {"type":"content_block_stop","index":1 }\n\n`, + `data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":65} }\n\n`, + `data: {"type":"message_stop" }\n\n`, + ], + }; - expect(await server.getRequestBodyJson()).toStrictEqual({ - stream: true, - model: 'claude-3-haiku-20240307', - max_tokens: 4096, // default value - messages: [{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }], + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool', + parameters: { + type: 'object', + properties: { value: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + ], + }, + prompt: TEST_PROMPT, + }); + + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'response-metadata', + id: 'msg_01GouTqNCGXzrj5LQ5jEkw67', + modelId: 'claude-3-haiku-20240307', + }, + { + type: 'text-delta', + textDelta: 'Okay', + }, + { + type: 'text-delta', + textDelta: '!', + }, + { + type: 'tool-call-delta', + toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '', + }, + { + type: 'tool-call-delta', + toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '{"value', + }, + { + type: 'tool-call-delta', + toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '":', + }, + { + type: 'tool-call-delta', + toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: '"Spark', + }, + { + type: 'tool-call-delta', + toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: 'le', + }, + { + type: 'tool-call-delta', + toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', + toolCallType: 'function', + toolName: 'test-tool', + argsTextDelta: ' Day"}', + }, + { + type: 'tool-call', + toolCallId: 'toolu_01DBsB4vvYLnBDzZ5rBSxSLs', + toolCallType: 'function', + toolName: 'test-tool', + args: '{"value":"Sparkle Day"}', + }, + { + type: 'finish', + finishReason: 'tool-calls', + usage: { promptTokens: 441, completionTokens: 65 }, + providerMetadata: undefined, + }, + ]); }); - }); - it('should pass headers', async () => { - prepareStreamResponse({ content: [] }); + it('should forward error chunks', async () => { + server.urls['https://api.anthropic.com/v1/messages'].response = { + type: 'stream-chunks', + chunks: [ + `data: {"type":"message_start","message":{"id":"msg_01KfpJoAEabmH2iHRRFjQMAG","type":"message","role":"assistant","content":[],"model":"claude-3-haiku-20240307","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":17,"output_tokens":1}} }\n\n`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }\n\n`, + `data: {"type": "ping"}\n\n`, + `data: {"type":"error","error":{"type":"error","message":"test error"}}\n\n`, + ], + }; - const provider = createAnthropic({ - apiKey: 'test-api-key', - headers: { - 'Custom-Provider-Header': 'provider-header-value', - }, - }); + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); - await provider('claude-3-haiku-20240307').doStream({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - headers: { - 'Custom-Request-Header': 'request-header-value', - }, + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'response-metadata', + id: 'msg_01KfpJoAEabmH2iHRRFjQMAG', + modelId: 'claude-3-haiku-20240307', + }, + { type: 'error', error: { type: 'error', message: 'test error' } }, + ]); + }); + + it('should expose the raw response headers', async () => { + prepareStreamResponse({ + content: [], + headers: { 'test-header': 'test-value' }, + }); + + const { rawResponse } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(rawResponse?.headers).toStrictEqual({ + // default headers: + 'content-type': 'text/event-stream', + 'cache-control': 'no-cache', + connection: 'keep-alive', + + // custom header + 'test-header': 'test-value', + }); + }); + + it('should pass the messages and the model', async () => { + prepareStreamResponse({ content: [] }); + + await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(await server.calls[0].requestBody).toStrictEqual({ + stream: true, + model: 'claude-3-haiku-20240307', + max_tokens: 4096, // default value + messages: [ + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ], + }); }); - const requestHeaders = await server.getRequestHeaders(); + it('should pass headers', async () => { + prepareStreamResponse({ content: [] }); - expect(requestHeaders).toStrictEqual({ - 'anthropic-version': '2023-06-01', - 'content-type': 'application/json', - 'custom-provider-header': 'provider-header-value', - 'custom-request-header': 'request-header-value', - 'x-api-key': 'test-api-key', - }); - }); + const provider = createAnthropic({ + apiKey: 'test-api-key', + headers: { + 'Custom-Provider-Header': 'provider-header-value', + }, + }); + + await provider('claude-3-haiku-20240307').doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + headers: { + 'Custom-Request-Header': 'request-header-value', + }, + }); + + expect(server.calls[0].requestHeaders).toStrictEqual({ + 'anthropic-version': '2023-06-01', + 'content-type': 'application/json', + 'custom-provider-header': 'provider-header-value', + 'custom-request-header': 'request-header-value', + 'x-api-key': 'test-api-key', + }); + }); + + it('should support cache control', async () => { + server.urls['https://api.anthropic.com/v1/messages'].response = { + type: 'stream-chunks', + chunks: [ + `data: {"type":"message_start","message":{"id":"msg_01KfpJoAEabmH2iHRRFjQMAG","type":"message","role":"assistant","content":[],` + + `"model":"claude-3-haiku-20240307","stop_reason":null,"stop_sequence":null,"usage":` + + // send cache output tokens: + `{"input_tokens":17,"output_tokens":1,"cache_creation_input_tokens":10,"cache_read_input_tokens":5}} }\n\n`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }\n\n`, + `data: {"type": "ping"}\n\n`, + `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"${'Hello'}"} }\n\n`, + `data: {"type":"content_block_stop","index":0 }\n\n`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":227} }\n\n`, + `data: {"type":"message_stop" }\n\n`, + ], + }; - it('should support cache control', async () => { - server.responseChunks = [ - `data: {"type":"message_start","message":{"id":"msg_01KfpJoAEabmH2iHRRFjQMAG","type":"message","role":"assistant","content":[],` + - `"model":"claude-3-haiku-20240307","stop_reason":null,"stop_sequence":null,"usage":` + - // send cache output tokens: - `{"input_tokens":17,"output_tokens":1,"cache_creation_input_tokens":10,"cache_read_input_tokens":5}} }\n\n`, - `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }\n\n`, - `data: {"type": "ping"}\n\n`, - `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"${'Hello'}"} }\n\n`, - `data: {"type":"content_block_stop","index":0 }\n\n`, - `data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":227} }\n\n`, - `data: {"type":"message_stop" }\n\n`, - ]; - - const model = provider('claude-3-haiku-20240307', { - cacheControl: true, - }); + const model = provider('claude-3-haiku-20240307', { + cacheControl: true, + }); - const { stream } = await model.doStream({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - }); + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); - // note: space moved to last chunk bc of trimming - expect(await convertReadableStreamToArray(stream)).toStrictEqual([ - { - type: 'response-metadata', - id: 'msg_01KfpJoAEabmH2iHRRFjQMAG', - modelId: 'claude-3-haiku-20240307', - }, - { type: 'text-delta', textDelta: 'Hello' }, - { - type: 'finish', - finishReason: 'stop', - usage: { promptTokens: 17, completionTokens: 227 }, - providerMetadata: { - anthropic: { - cacheCreationInputTokens: 10, - cacheReadInputTokens: 5, + // note: space moved to last chunk bc of trimming + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'response-metadata', + id: 'msg_01KfpJoAEabmH2iHRRFjQMAG', + modelId: 'claude-3-haiku-20240307', + }, + { type: 'text-delta', textDelta: 'Hello' }, + { + type: 'finish', + finishReason: 'stop', + usage: { promptTokens: 17, completionTokens: 227 }, + providerMetadata: { + anthropic: { + cacheCreationInputTokens: 10, + cacheReadInputTokens: 5, + }, }, }, - }, - ]); - }); + ]); + }); - it('should send request body', async () => { - prepareStreamResponse({ content: [] }); + it('should send request body', async () => { + prepareStreamResponse({ content: [] }); - const { request } = await model.doStream({ - inputFormat: 'prompt', - mode: { type: 'regular' }, - prompt: TEST_PROMPT, - }); + const { request } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); - expect(request).toStrictEqual({ - body: '{"model":"claude-3-haiku-20240307","max_tokens":4096,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}],"stream":true}', + expect(request).toStrictEqual({ + body: '{"model":"claude-3-haiku-20240307","max_tokens":4096,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}],"stream":true}', + }); }); }); }); diff --git a/packages/fireworks/src/fireworks-image-model.test.ts b/packages/fireworks/src/fireworks-image-model.test.ts index f391b6046cdc..cce2035d656a 100644 --- a/packages/fireworks/src/fireworks-image-model.test.ts +++ b/packages/fireworks/src/fireworks-image-model.test.ts @@ -1,7 +1,7 @@ -import { BinaryTestServer } from '@ai-sdk/provider-utils/test'; +import { FetchFunction } from '@ai-sdk/provider-utils'; +import { createTestServer } from '@ai-sdk/provider-utils/test'; import { describe, expect, it } from 'vitest'; import { FireworksImageModel } from './fireworks-image-model'; -import { FetchFunction } from '@ai-sdk/provider-utils'; const prompt = 'A cute baby sea otter'; @@ -31,35 +31,26 @@ function createSizeModel() { ); } -function createStabilityModel() { - return new FireworksImageModel('accounts/stability/models/sd3', { - provider: 'fireworks', - baseURL: 'https://api.stability.ai', - headers: () => ({ 'api-key': 'test-key' }), +describe('FireworksImageModel', () => { + const server = createTestServer({ + 'https://api.example.com/*': { + response: { + type: 'binary', + body: Buffer.from('test-binary-content'), + }, + }, + 'https://api.size-example.com/*': { + response: { + type: 'binary', + body: Buffer.from('test-binary-content'), + }, + }, }); -} - -const basicUrl = - 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image'; -const sizeUrl = - 'https://api.size-example.com/image_generation/accounts/fireworks/models/playground-v2-5-1024px-aesthetic'; -const stabilityUrl = - 'https://api.stability.ai/v2beta/stable-image/generate/sd3'; -describe('FireworksImageModel', () => { describe('doGenerate', () => { - const server = new BinaryTestServer([basicUrl, sizeUrl, stabilityUrl]); - server.setupTestEnvironment(); - - function prepareBinaryResponse(url: string) { - const mockImageBuffer = Buffer.from('mock-image-data'); - server.setResponseFor(url, { body: mockImageBuffer }); - } - it('should pass the correct parameters including aspect ratio and seed', async () => { - prepareBinaryResponse(basicUrl); - const model = createBasicModel(); + await model.doGenerate({ prompt, n: 1, @@ -69,8 +60,7 @@ describe('FireworksImageModel', () => { providerOptions: { fireworks: { additional_param: 'value' } }, }); - const request = await server.getRequestDataFor(basicUrl); - expect(await request.bodyJson()).toStrictEqual({ + expect(await server.calls[0].requestBody).toStrictEqual({ prompt, aspect_ratio: '16:9', seed: 42, @@ -78,9 +68,25 @@ describe('FireworksImageModel', () => { }); }); - it('should pass headers', async () => { - prepareBinaryResponse(basicUrl); + it('should call the correct url', async () => { + const model = createBasicModel(); + + await model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: '16:9', + seed: 42, + providerOptions: { fireworks: { additional_param: 'value' } }, + }); + expect(server.calls[0].requestMethod).toStrictEqual('POST'); + expect(server.calls[0].requestUrl).toStrictEqual( + 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image', + ); + }); + + it('should pass headers', async () => { const modelWithHeaders = createBasicModel({ headers: () => ({ 'Custom-Provider-Header': 'provider-header-value', @@ -99,8 +105,7 @@ describe('FireworksImageModel', () => { }, }); - const request = await server.getRequestDataFor(basicUrl); - expect(request.headers()).toStrictEqual({ + expect(server.calls[0].requestHeaders).toStrictEqual({ 'content-type': 'application/json', 'custom-provider-header': 'provider-header-value', 'custom-request-header': 'request-header-value', @@ -108,7 +113,9 @@ describe('FireworksImageModel', () => { }); it('should handle empty response body', async () => { - server.setResponseFor(basicUrl, { body: null }); + server.urls['https://api.example.com/*'].response = { + type: 'empty', + }; const model = createBasicModel(); await expect( @@ -123,7 +130,7 @@ describe('FireworksImageModel', () => { ).rejects.toMatchObject({ message: 'Response body is empty', statusCode: 200, - url: basicUrl, + url: 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image', requestBodyValues: { prompt: 'A cute baby sea otter', }, @@ -131,10 +138,11 @@ describe('FireworksImageModel', () => { }); it('should handle API errors', async () => { - server.setResponseFor(basicUrl, { + server.urls['https://api.example.com/*'].response = { + type: 'error', status: 400, - body: Buffer.from('Bad Request'), - }); + body: 'Bad Request', + }; const model = createBasicModel(); await expect( @@ -149,7 +157,7 @@ describe('FireworksImageModel', () => { ).rejects.toMatchObject({ message: 'Bad Request', statusCode: 400, - url: basicUrl, + url: 'https://api.example.com/workflows/accounts/fireworks/models/flux-1-dev-fp8/text_to_image', requestBodyValues: { prompt: 'A cute baby sea otter', }, @@ -158,8 +166,6 @@ describe('FireworksImageModel', () => { }); it('should handle size parameter for supported models', async () => { - prepareBinaryResponse(sizeUrl); - const sizeModel = createSizeModel(); await sizeModel.doGenerate({ @@ -171,8 +177,7 @@ describe('FireworksImageModel', () => { providerOptions: {}, }); - const request = await server.getRequestDataFor(sizeUrl); - expect(await request.bodyJson()).toStrictEqual({ + expect(await server.calls[0].requestBody).toStrictEqual({ prompt, width: '1024', height: '768', @@ -180,49 +185,48 @@ describe('FireworksImageModel', () => { }); }); - it('should return appropriate warnings based on model capabilities', async () => { - prepareBinaryResponse(basicUrl); - - // Test workflow model (supports aspectRatio but not size) - const model = createBasicModel(); - const result1 = await model.doGenerate({ - prompt, - n: 1, - size: '1024x1024', - aspectRatio: '1:1', - seed: 123, - providerOptions: {}, - }); + describe('warnings', () => { + it('should return size warning on workflow model', async () => { + const model = createBasicModel(); - expect(result1.warnings).toContainEqual({ - type: 'unsupported-setting', - setting: 'size', - details: - 'This model does not support the `size` option. Use `aspectRatio` instead.', + const result1 = await model.doGenerate({ + prompt, + n: 1, + size: '1024x1024', + aspectRatio: '1:1', + seed: 123, + providerOptions: {}, + }); + + expect(result1.warnings).toContainEqual({ + type: 'unsupported-setting', + setting: 'size', + details: + 'This model does not support the `size` option. Use `aspectRatio` instead.', + }); }); - // Test size-supporting model - prepareBinaryResponse(sizeUrl); - const sizeModel = createSizeModel(); + it('should return aspectRatio warning on size-supporting model', async () => { + const sizeModel = createSizeModel(); - const result2 = await sizeModel.doGenerate({ - prompt, - n: 1, - size: '1024x1024', - aspectRatio: '1:1', - seed: 123, - providerOptions: {}, - }); + const result2 = await sizeModel.doGenerate({ + prompt, + n: 1, + size: '1024x1024', + aspectRatio: '1:1', + seed: 123, + providerOptions: {}, + }); - expect(result2.warnings).toContainEqual({ - type: 'unsupported-setting', - setting: 'aspectRatio', - details: 'This model does not support the `aspectRatio` option.', + expect(result2.warnings).toContainEqual({ + type: 'unsupported-setting', + setting: 'aspectRatio', + details: 'This model does not support the `aspectRatio` option.', + }); }); }); it('should respect the abort signal', async () => { - prepareBinaryResponse(basicUrl); const model = createBasicModel(); const controller = new AbortController(); diff --git a/packages/provider-utils/src/test/binary-test-server.test.ts b/packages/provider-utils/src/test/binary-test-server.test.ts deleted file mode 100644 index 321cd49f2710..000000000000 --- a/packages/provider-utils/src/test/binary-test-server.test.ts +++ /dev/null @@ -1,186 +0,0 @@ -import { - describe, - it, - expect, - beforeEach, - afterEach, - beforeAll, - afterAll, - vi, -} from 'vitest'; -import { BinaryTestServer } from './binary-test-server'; - -describe('BinaryTestServer', () => { - let server: BinaryTestServer; - - beforeEach(() => { - vi.clearAllMocks(); - }); - - afterEach(() => { - vi.restoreAllMocks(); - }); - - describe('constructor', () => { - it('should initialize with a single URL', () => { - const server = new BinaryTestServer('http://example.com'); - expect(server.server).toBeDefined(); - }); - - it('should initialize with multiple URLs', () => { - const server = new BinaryTestServer([ - 'http://example.com', - 'http://test.com', - ]); - expect(server.server).toBeDefined(); - }); - }); - - describe('setResponseFor', () => { - beforeAll(() => { - server = new BinaryTestServer('http://example.com'); - server.server.listen(); - }); - - afterAll(() => { - server.server.close(); - }); - - it('should set response options for a valid URL', () => { - const buffer = Buffer.from('test data'); - server.setResponseFor('http://example.com/', { - body: buffer, - headers: { 'content-type': 'application/octet-stream' }, - status: 201, - }); - }); - - it('should throw error for invalid URL', () => { - expect(() => - server.setResponseFor('http://invalid.com', { status: 200 }), - ).toThrow('No endpoint configured for URL'); - }); - }); - - describe('request handling', () => { - beforeAll(() => { - server = new BinaryTestServer('http://example.com'); - server.server.listen(); - }); - - afterAll(() => { - server.server.close(); - }); - - beforeEach(() => { - server.server.resetHandlers(); - }); - - it('should handle JSON requests', async () => { - const testData = { test: 'data' }; - const fetchSpy = vi.spyOn(global, 'fetch'); - - await fetch('http://example.com', { - method: 'POST', - headers: { 'content-type': 'application/json' }, - body: JSON.stringify(testData), - }); - - expect(fetchSpy).toHaveBeenCalledWith( - 'http://example.com', - expect.objectContaining({ - method: 'POST', - headers: expect.objectContaining({ - 'content-type': 'application/json', - }), - }), - ); - - const requestData = await server.getRequestDataFor('http://example.com/'); - const bodyJson = await requestData.bodyJson(); - expect(bodyJson).toEqual(testData); - }); - - it('should handle form data requests', async () => { - const formData = new FormData(); - formData.append('field', 'value'); - const fetchSpy = vi.spyOn(global, 'fetch'); - - await fetch('http://example.com', { - method: 'POST', - body: formData, - }); - - expect(fetchSpy).toHaveBeenCalledWith( - 'http://example.com', - expect.objectContaining({ - method: 'POST', - body: expect.any(FormData), - }), - ); - - const requestData = await server.getRequestDataFor('http://example.com'); - const formDataReceived = await requestData.bodyFormData(); - expect(formDataReceived.get('field')).toBe('value'); - }); - - it('should handle custom response configurations', async () => { - const responseBuffer = Buffer.from('test response'); - const fetchSpy = vi.spyOn(global, 'fetch'); - - server.setResponseFor('http://example.com', { - body: responseBuffer, - headers: { 'x-custom': 'test' }, - status: 201, - }); - - const response = await fetch('http://example.com', { method: 'POST' }); - - expect(fetchSpy).toHaveBeenCalledWith( - 'http://example.com', - expect.objectContaining({ method: 'POST' }), - ); - - expect(response.status).toBe(201); - expect(response.headers.get('x-custom')).toBe('test'); - const responseData = await response.arrayBuffer(); - expect(Buffer.from(responseData)).toEqual(responseBuffer); - }); - }); - - describe('URL handling', () => { - let server: BinaryTestServer; - - beforeEach(() => { - server = new BinaryTestServer('http://example.com'); - server.server.listen(); - // Set default response - server.setResponseFor('http://example.com', { - status: 200, - body: null, - }); - }); - - afterEach(() => { - server.server.resetHandlers(); - server.server.close(); - }); - - it('should handle search params', async () => { - const response = await fetch('http://example.com?param=value', { - method: 'POST', - body: JSON.stringify({ test: true }), - }); - - expect(response.status).toBe(200); - - const requestData = await server.getRequestDataFor('http://example.com'); - expect(requestData.urlSearchParams().get('param')).toBe('value'); - }); - - it('should handle relative URLs', () => { - const server = new BinaryTestServer('/api/endpoint'); - expect(server.server).toBeDefined(); - }); - }); -}); diff --git a/packages/provider-utils/src/test/binary-test-server.ts b/packages/provider-utils/src/test/binary-test-server.ts deleted file mode 100644 index 73fe5effb329..000000000000 --- a/packages/provider-utils/src/test/binary-test-server.ts +++ /dev/null @@ -1,131 +0,0 @@ -import { HttpResponse, http } from 'msw'; -import { SetupServer, setupServer } from 'msw/node'; - -export class BinaryTestServer { - readonly server: SetupServer; - private endpoints: Map< - string, - { - responseBody: Buffer | null; - responseHeaders: Record; - responseStatus: number; - request: Request | undefined; - } - > = new Map(); - - constructor(urls: string | string[]) { - const urlList = Array.isArray(urls) ? urls : [urls]; - - // Initialize endpoints - urlList.forEach(url => { - const normalizedUrl = this.normalizeUrl(url); - this.endpoints.set(normalizedUrl, { - responseBody: null, - responseHeaders: {}, - responseStatus: 200, - request: undefined, - }); - }); - - this.server = setupServer( - ...urlList.map(url => - http.post(this.normalizeUrl(url), ({ request }) => { - const endpoint = this.endpoints.get(this.normalizeUrl(request.url)); - if (!endpoint) { - return new HttpResponse(null, { status: 500 }); - } - endpoint.request = request; - - if (endpoint.responseBody === null) { - return new HttpResponse(null, { status: endpoint.responseStatus }); - } - - return new HttpResponse(endpoint.responseBody, { - status: endpoint.responseStatus, - headers: endpoint.responseHeaders, - }); - }), - ), - ); - } - - private normalizeUrl(url: string): string { - try { - // Parse URL and remove search params for endpoint matching - const urlObj = new URL(url); - urlObj.search = ''; // Clear search params for matching - const normalized = urlObj.toString(); - return normalized.endsWith('/') ? normalized.slice(0, -1) : normalized; - } catch { - // If not a valid URL, assume it's a path and return as-is - return url.endsWith('/') ? url.slice(0, -1) : url; - } - } - - setResponseFor( - url: string, - options: { - body?: Buffer | null; - headers?: Record; - status?: number; - }, - ) { - // Normalize the URL before lookup - const normalizedUrl = this.normalizeUrl(url); - const endpoint = this.endpoints.get(normalizedUrl); - if (!endpoint) { - throw new Error(`No endpoint configured for URL: ${url}`); - } - if (options.body !== undefined) endpoint.responseBody = options.body; - if (options.headers) endpoint.responseHeaders = options.headers; - if (options.status) endpoint.responseStatus = options.status; - } - - async getRequestDataFor(url: string) { - // Normalize the URL before lookup - const normalizedUrl = this.normalizeUrl(url); - const endpoint = this.endpoints.get(normalizedUrl); - if (!endpoint) { - throw new Error(`No endpoint configured for URL: ${url}`); - } - expect(endpoint.request).toBeDefined(); - - return { - bodyJson: async () => { - const text = await endpoint.request!.text(); - return JSON.parse(text); - }, - bodyFormData: async () => { - const contentType = endpoint.request!.headers.get('content-type'); - if (contentType?.includes('multipart/form-data')) { - return endpoint.request!.formData(); - } - throw new Error('Request content-type is not multipart/form-data'); - }, - headers: () => { - const headersObject: Record = {}; - endpoint.request!.headers.forEach((value, key) => { - headersObject[key] = value; - }); - return headersObject; - }, - urlSearchParams: () => new URL(endpoint.request!.url).searchParams, - url: () => new URL(endpoint.request!.url).toString(), - }; - } - - setupTestEnvironment() { - beforeAll(() => this.server.listen()); - beforeEach(() => { - // Reset all endpoints - this.endpoints.forEach(endpoint => { - endpoint.responseBody = null; - endpoint.request = undefined; - endpoint.responseHeaders = {}; - endpoint.responseStatus = 200; - }); - }); - afterEach(() => this.server.resetHandlers()); - afterAll(() => this.server.close()); - } -} diff --git a/packages/provider-utils/src/test/index.ts b/packages/provider-utils/src/test/index.ts index 9451e80bc938..874b38ed8148 100644 --- a/packages/provider-utils/src/test/index.ts +++ b/packages/provider-utils/src/test/index.ts @@ -1,4 +1,3 @@ -export * from './binary-test-server'; export * from './convert-array-to-async-iterable'; export * from './convert-array-to-readable-stream'; export * from './convert-async-iterable-to-array'; diff --git a/packages/provider-utils/src/test/json-test-server.ts b/packages/provider-utils/src/test/json-test-server.ts index 212e80d102f7..2916961f6821 100644 --- a/packages/provider-utils/src/test/json-test-server.ts +++ b/packages/provider-utils/src/test/json-test-server.ts @@ -1,6 +1,9 @@ import { HttpResponse, http } from 'msw'; import { SetupServer, setupServer } from 'msw/node'; +/** + * @deprecated Use createTestServer instead + */ export class JsonTestServer { readonly server: SetupServer; @@ -9,6 +12,9 @@ export class JsonTestServer { request: Request | undefined; + /** + * @deprecated Use createTestServer instead + */ constructor(url: string) { const responseBodyJson = () => this.responseBodyJson; diff --git a/packages/provider-utils/src/test/streaming-test-server.ts b/packages/provider-utils/src/test/streaming-test-server.ts index ffc9754605a4..99e24ce79630 100644 --- a/packages/provider-utils/src/test/streaming-test-server.ts +++ b/packages/provider-utils/src/test/streaming-test-server.ts @@ -1,6 +1,9 @@ import { HttpResponse, http } from 'msw'; import { SetupServer, setupServer } from 'msw/node'; +/** + * @deprecated Use createTestServer instead + */ export class StreamingTestServer { readonly server: SetupServer; @@ -9,6 +12,9 @@ export class StreamingTestServer { request: Request | undefined; + /** + * @deprecated Use createTestServer instead + */ constructor(url: string) { const responseChunks = () => this.responseChunks; diff --git a/packages/provider-utils/src/test/unified-test-server.ts b/packages/provider-utils/src/test/unified-test-server.ts index 9786e56c5713..2c0bf3fd9cd6 100644 --- a/packages/provider-utils/src/test/unified-test-server.ts +++ b/packages/provider-utils/src/test/unified-test-server.ts @@ -1,64 +1,70 @@ -import { JSONValue } from '@ai-sdk/provider'; -import { http, HttpResponse } from 'msw'; +import { http, HttpResponse, JsonBodyType } from 'msw'; import { setupServer } from 'msw/node'; +import { convertArrayToReadableStream } from './convert-array-to-readable-stream'; -export type UrlHandler = - | { - type: 'json-value'; - response?: { +export type UrlHandler = { + response?: + | { + type: 'json-value'; headers?: Record; - body: JSONValue; - }; - } - | { - type: 'binary'; - response?: { + body: JsonBodyType; + } + | { + type: 'stream-chunks'; + headers?: Record; + chunks: Array; + } + | { + type: 'binary'; headers?: Record; body: Buffer; + } + | { + type: 'empty'; + headers?: Record; + status?: number; + } + | { + type: 'error'; + headers?: Record; + status?: number; + body?: string; }; - }; - -export type FullUrlHandler = - | { - type: 'json-value'; - response: - | { - headers: Record | undefined; - body: JSONValue; - } - | undefined; - } - | { - type: 'binary'; - response: - | { - headers: Record | undefined; - body: Buffer; - } - | undefined; - }; +}; -// Mapped type for URLS -export type FullHandlers = { - [url in keyof URLS]: URLS[url] extends { type: 'json-value' } - ? { +export type FullUrlHandler = { + response: + | { type: 'json-value'; - response: - | { - headers?: Record; - body: JSONValue; - } - | undefined; + headers?: Record; + body: JsonBodyType; + } + | { + type: 'stream-chunks'; + headers?: Record; + chunks: Array; } - : { + | { type: 'binary'; - response: - | { - headers?: Record; - body: Buffer; - } - | undefined; - }; + headers?: Record; + body: Buffer; + } + | { + type: 'error'; + headers?: Record; + status: number; + body?: string; + } + | { + type: 'empty'; + headers?: Record; + status?: number; + } + | undefined; +}; + +export type FullHandlers = { + [url in keyof URLS]: FullUrlHandler; }; class TestServerCall { @@ -99,15 +105,24 @@ export function createTestServer( urls: FullHandlers; calls: TestServerCall[]; } { + const originalRoutes = structuredClone(routes); // deep copy + const mswServer = setupServer( ...Object.entries(routes).map(([url, handler]) => { return http.all(url, ({ request, params }) => { calls.push(new TestServerCall(request)); - const handlerType = handler.type; + const response = handler.response; + + if (response === undefined) { + return HttpResponse.json({ error: 'Not Found' }, { status: 404 }); + } + + const handlerType = response.type; + switch (handlerType) { case 'json-value': - return HttpResponse.json(handler.response?.body, { + return HttpResponse.json(response.body, { status: 200, headers: { 'Content-Type': 'application/json', @@ -115,13 +130,40 @@ export function createTestServer( }, }); + case 'stream-chunks': + return new HttpResponse( + convertArrayToReadableStream(response.chunks).pipeThrough( + new TextEncoderStream(), + ), + { + status: 200, + headers: { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: 'keep-alive', + ...response.headers, + }, + }, + ); + case 'binary': { - return HttpResponse.arrayBuffer(handler.response?.body, { + return HttpResponse.arrayBuffer(response.body, { status: 200, headers: handler.response?.headers, }); } + case 'error': + return HttpResponse.text(response.body ?? 'Error', { + status: response.status ?? 500, + headers: response.headers, + }); + + case 'empty': + return new HttpResponse(null, { + status: response.status ?? 200, + }); + default: { const _exhaustiveCheck: never = handlerType; throw new Error(`Unknown response type: ${_exhaustiveCheck}`); @@ -139,6 +181,12 @@ export function createTestServer( beforeEach(() => { mswServer.resetHandlers(); + + // set the responses back to the original values + Object.entries(originalRoutes).forEach(([url, handler]) => { + routes[url].response = handler.response; + }); + calls = []; }); diff --git a/packages/replicate/src/replicate-image-model.test.ts b/packages/replicate/src/replicate-image-model.test.ts index 878cf09bd911..011efb66d628 100644 --- a/packages/replicate/src/replicate-image-model.test.ts +++ b/packages/replicate/src/replicate-image-model.test.ts @@ -8,12 +8,10 @@ const model = provider.image('black-forest-labs/flux-schnell'); describe('doGenerate', () => { const server = createTestServer({ - 'https://api.replicate.com/*': { - type: 'json-value', - }, + 'https://api.replicate.com/*': {}, 'https://replicate.delivery/*': { - type: 'binary', response: { + type: 'binary', body: Buffer.from('test-binary-content'), }, }, @@ -23,6 +21,7 @@ describe('doGenerate', () => { output = ['https://replicate.delivery/xezq/abc/out-0.webp'], }: { output?: string | Array } = {}) { server.urls['https://api.replicate.com/*'].response = { + type: 'json-value', body: { id: 's7x1e3dcmhrmc0cm8rbatcneec', model: 'black-forest-labs/flux-schnell',