feat (ai/core): middleware support (#2759)
lgrammel committed Sep 6, 2024
1 parent 7ee8d32 commit db61c53
Showing 23 changed files with 886 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .changeset/many-yaks-relate.md
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (ai/core): middleware support
2 changes: 1 addition & 1 deletion content/docs/03-ai-sdk-core/40-provider-management.mdx
@@ -5,7 +5,7 @@ description: Learn how to work with multiple providers

# Provider Management

-<Note>Provider management is an experimental feature.</Note>
+<Note type="warning">Provider management is an experimental feature.</Note>

When you work with multiple providers and models, it is often desirable to manage them in a central place
and access the models through simple string ids.
209 changes: 209 additions & 0 deletions content/docs/03-ai-sdk-core/45-middleware.mdx
@@ -0,0 +1,209 @@
---
title: Language Model Middleware
description: Learn how to use middleware to enhance the behavior of language models
---

# Language Model Middleware

<Note type="warning">
Language model middleware is an experimental feature.
</Note>

Language model middleware is a way to enhance the behavior of language models
by intercepting and modifying the calls to the language model.

It can be used to add features like guardrails, RAG, caching, and logging
in a language model agnostic way. Such middleware can be developed and
distributed independently of the language models it is applied to.

## Using Language Model Middleware

You can use language model middleware with the `wrapLanguageModel` function.
It takes a language model and a language model middleware and returns a new
language model that incorporates the middleware.

```ts
import { experimental_wrapLanguageModel as wrapLanguageModel } from 'ai';

const wrappedLanguageModel = wrapLanguageModel({
  model: yourModel,
  middleware: yourLanguageModelMiddleware,
});
```

The wrapped language model can be used just like any other language model, e.g. in `streamText`:

```ts highlight="2"
const result = await streamText({
  model: wrappedLanguageModel,
  prompt: 'What cities are in the United States?',
});
```
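Because `wrapLanguageModel` returns a regular language model, middleware composes by nesting wrapped models. A sketch (assuming the logging and caching middleware from the examples below):

```ts
import { experimental_wrapLanguageModel as wrapLanguageModel } from 'ai';

// inner wrapper: logging
const modelWithLogging = wrapLanguageModel({
  model: yourModel,
  middleware: yourLogMiddleware,
});

// outer wrapper: caching; the outer middleware sees each call first
const modelWithLoggingAndCaching = wrapLanguageModel({
  model: modelWithLogging,
  middleware: yourCacheMiddleware,
});
```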

## Implementing Language Model Middleware

<Note>
Implementing language model middleware is advanced functionality and requires
a solid understanding of the [language model
specification](https://github.com/vercel/ai/blob/main/packages/provider/src/language-model/v1/language-model-v1.ts).
</Note>

You can implement any of the following three functions to modify the behavior of the language model (a combined pass-through skeleton follows the list):

1. `transformParams`: Transforms the parameters before they are passed to the language model, for both `doGenerate` and `doStream`.
2. `wrapGenerate`: Wraps the `doGenerate` method of the [language model](https://github.com/vercel/ai/blob/main/packages/provider/src/language-model/v1/language-model-v1.ts).
You can modify the parameters, call the language model, and modify the result.
3. `wrapStream`: Wraps the `doStream` method of the [language model](https://github.com/vercel/ai/blob/main/packages/provider/src/language-model/v1/language-model-v1.ts).
You can modify the parameters, call the language model, and modify the result.
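
As a point of reference before the concrete examples, here is a minimal pass-through skeleton (a sketch, not part of the SDK) that combines all three hooks:

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

export const passthroughMiddleware: LanguageModelV1Middleware = {
  // called before doGenerate and doStream; returning the params
  // unchanged is a no-op
  transformParams: async ({ params }) => params,

  wrapGenerate: async ({ doGenerate }) => {
    // pre-processing could go here
    const result = await doGenerate();
    // post-processing could go here
    return result;
  },

  wrapStream: async ({ doStream }) => {
    // the result contains the stream plus call metadata
    return doStream();
  },
};
```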

## Examples

<Note>
These examples are not meant to be used in production. They are just to show
how you can use middleware to enhance the behavior of language models.
</Note>

### Logging

This example shows how to log the parameters and generated text of a language model call.

```ts
import type {
  Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware,
  LanguageModelV1StreamPart,
} from 'ai';

export const yourLogMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate, params }) => {
    console.log('doGenerate called');
    console.log(`params: ${JSON.stringify(params, null, 2)}`);

    const result = await doGenerate();

    console.log('doGenerate finished');
    console.log(`generated text: ${result.text}`);

    return result;
  },

  wrapStream: async ({ doStream, params }) => {
    console.log('doStream called');
    console.log(`params: ${JSON.stringify(params, null, 2)}`);

    const { stream, ...rest } = await doStream();

    let generatedText = '';

    const transformStream = new TransformStream<
      LanguageModelV1StreamPart,
      LanguageModelV1StreamPart
    >({
      transform(chunk, controller) {
        if (chunk.type === 'text-delta') {
          generatedText += chunk.textDelta;
        }

        controller.enqueue(chunk);
      },

      flush() {
        console.log('doStream finished');
        console.log(`generated text: ${generatedText}`);
      },
    });

    return {
      stream: stream.pipeThrough(transformStream),
      ...rest,
    };
  },
};
```

### Caching

This example shows how to build a simple cache for the generated text of a language model call.

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

const cache = new Map<string, any>();

export const yourCacheMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate, params }) => {
    const cacheKey = JSON.stringify(params);

    if (cache.has(cacheKey)) {
      return cache.get(cacheKey);
    }

    const result = await doGenerate();

    cache.set(cacheKey, result);

    return result;
  },

  // here you would implement the caching logic for streaming
};
```
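One way the streaming side could work (a hedged sketch, not production code) is to record the stream parts on the first call and replay them from the cache afterwards:

```ts
import type {
  Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware,
  LanguageModelV1StreamPart,
} from 'ai';

const streamCache = new Map<string, any>();

export const yourStreamCacheMiddleware: LanguageModelV1Middleware = {
  wrapStream: async ({ doStream, params }) => {
    const cacheKey = JSON.stringify(params);
    const cached = streamCache.get(cacheKey);

    if (cached != null) {
      // replay the recorded parts as a fresh stream
      return {
        ...cached.rest,
        stream: new ReadableStream<LanguageModelV1StreamPart>({
          start(controller) {
            for (const part of cached.parts) {
              controller.enqueue(part);
            }
            controller.close();
          },
        }),
      };
    }

    const { stream, ...rest } = await doStream();
    const parts: LanguageModelV1StreamPart[] = [];

    const recorder = new TransformStream<
      LanguageModelV1StreamPart,
      LanguageModelV1StreamPart
    >({
      transform(part, controller) {
        parts.push(part);
        controller.enqueue(part);
      },
      flush() {
        // only cache once the stream has completed
        streamCache.set(cacheKey, { parts, rest });
      },
    });

    return { stream: stream.pipeThrough(recorder), ...rest };
  },
};
```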

### Retrieval Augmented Generation (RAG)

This example shows how to use RAG as middleware.

<Note>
Helper functions like `getLastUserMessageText` and `findSources` are not part
of the AI SDK. They are just used in this example to illustrate the concept of
RAG.
</Note>

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

export const yourRagMiddleware: LanguageModelV1Middleware = {
  transformParams: async ({ params }) => {
    const lastUserMessageText = getLastUserMessageText({
      prompt: params.prompt,
    });

    if (lastUserMessageText == null) {
      return params; // do not use RAG (send unmodified parameters)
    }

    const instruction =
      'Use the following information to answer the question:\n' +
      findSources({ text: lastUserMessageText })
        .map(chunk => JSON.stringify(chunk))
        .join('\n');

    return addToLastUserMessage({ params, text: instruction });
  },
};
```

### Guardrails

Guardrails are a way to ensure that the generated text of a language model call
is safe and appropriate. This example shows how to use guardrails as middleware.

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

export const yourGuardrailMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate }) => {
    const { text, ...rest } = await doGenerate();

    // filtering approach, e.g. for PII or other sensitive information:
    const cleanedText = text?.replace(/badword/g, '<REDACTED>');

    return { text: cleanedText, ...rest };
  },

  // here you would implement the guardrail logic for streaming
  // Note: streaming guardrails are difficult to implement, because
  // you do not know the full content of the stream until it's finished.
};
```
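One hedged sketch for the streaming side (not from this commit) trades latency for safety: buffer the whole stream, redact the combined text, then re-emit it:

```ts
import type {
  Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware,
  LanguageModelV1StreamPart,
} from 'ai';

export const yourBufferingGuardrailMiddleware: LanguageModelV1Middleware = {
  wrapStream: async ({ doStream }) => {
    const { stream, ...rest } = await doStream();

    // buffer the entire stream so the full text can be inspected
    const reader = stream.getReader();
    let fullText = '';
    const otherParts: LanguageModelV1StreamPart[] = [];

    while (true) {
      const { done, value } = await reader.read();
      if (done) break;

      if (value.type === 'text-delta') {
        fullText += value.textDelta;
      } else {
        otherParts.push(value);
      }
    }

    const cleanedText = fullText.replace(/badword/g, '<REDACTED>');

    // re-emit the redacted text as a single chunk, followed by the
    // buffered non-text parts (e.g. the finish part)
    return {
      ...rest,
      stream: new ReadableStream<LanguageModelV1StreamPart>({
        start(controller) {
          controller.enqueue({ type: 'text-delta', textDelta: cleanedText });
          for (const part of otherParts) {
            controller.enqueue(part);
          }
          controller.close();
        },
      }),
    };
  },
};
```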
@@ -1,11 +1,11 @@
---
-title: experimental_createProviderRegistry
+title: createProviderRegistry
description: Registry for managing multiple providers and models (API Reference)
---

-# `experimental_createProviderRegistry()`
+# `createProviderRegistry()`

-<Note>Provider management is an experimental feature.</Note>
+<Note type="warning">Provider management is an experimental feature.</Note>

When you work with multiple providers and models, it is often desirable to manage them
in a central place and access the models through simple string ids.
6 changes: 3 additions & 3 deletions content/docs/07-reference/ai-sdk-core/42-custom-provider.mdx
@@ -1,11 +1,11 @@
---
-title: experimental_customProvider
+title: customProvider
description: Custom provider that uses models from a different provider (API Reference)
---

-# `experimental_customProvider()`
+# `customProvider()`

-<Note>Provider management is an experimental feature.</Note>
+<Note type="warning">Provider management is an experimental feature.</Note>

With a custom provider, you can map ids to any model.
This allows you to set up custom model configurations, alias names, and more.
65 changes: 65 additions & 0 deletions content/docs/07-reference/ai-sdk-core/60-wrap-language-model.mdx
@@ -0,0 +1,65 @@
---
title: wrapLanguageModel
description: Function for wrapping a language model with middleware (API Reference)
---

# `wrapLanguageModel()`

<Note type="warning">
Language model middleware is an experimental feature.
</Note>

The `experimental_wrapLanguageModel` function provides a way to enhance the behavior of language models
by wrapping them with middleware.
See [Language Model Middleware](/docs/ai-sdk-core/middleware) for more information on middleware.

```ts
import { experimental_wrapLanguageModel as wrapLanguageModel } from 'ai';

const wrappedLanguageModel = wrapLanguageModel({
  model: yourModel,
  middleware: yourLanguageModelMiddleware,
});
```

## Import

<Snippet
  text={`import { experimental_wrapLanguageModel as wrapLanguageModel } from "ai"`}
  prompt={false}
/>

## API Signature

### Parameters

<PropertiesTable
  content={[
    {
      name: 'model',
      type: 'LanguageModelV1',
      description: 'The original LanguageModelV1 instance to be wrapped.',
    },
    {
      name: 'middleware',
      type: 'Experimental_LanguageModelV1Middleware',
      description: 'The middleware to be applied to the language model.',
    },
    {
      name: 'modelId',
      type: 'string',
      description:
        "Optional custom model ID to override the original model's ID.",
    },
    {
      name: 'providerId',
      type: 'string',
      description:
        "Optional custom provider ID to override the original model's provider.",
    },
  ]}
/>

### Returns

A new `LanguageModelV1` instance with middleware applied.
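
For example, using the optional `modelId` and `providerId` parameters documented above to relabel the wrapped model (a brief sketch):

```ts
import { experimental_wrapLanguageModel as wrapLanguageModel } from 'ai';

const wrappedLanguageModel = wrapLanguageModel({
  model: yourModel,
  middleware: yourLanguageModelMiddleware,
  // report the middleware-enhanced variant under custom ids
  modelId: 'your-model-with-middleware',
  providerId: 'your-provider',
});
```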
@@ -0,0 +1,46 @@
---
title: LanguageModelV1Middleware
description: Middleware for enhancing language model behavior (API Reference)
---

# `LanguageModelV1Middleware`

<Note type="warning">
Language model middleware is an experimental feature.
</Note>

Language model middleware provides a way to enhance the behavior of language models
by intercepting and modifying the calls to the language model. It can be used to add
features like guardrails, RAG, caching, and logging in a language model agnostic way.

See [Language Model Middleware](/docs/ai-sdk-core/middleware) for more information.

## Import

<Snippet
  text={`import { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from "ai"`}
  prompt={false}
/>

## API Signature

<PropertiesTable
  content={[
    {
      name: 'transformParams',
      type: '({ type: "generate" | "stream", params: LanguageModelV1CallOptions }) => Promise<LanguageModelV1CallOptions>',
      description:
        'Transforms the parameters before they are passed to the language model.',
    },
    {
      name: 'wrapGenerate',
      type: '({ doGenerate: DoGenerateFunction, params: LanguageModelV1CallOptions, model: LanguageModelV1 }) => Promise<DoGenerateResult>',
      description: 'Wraps the generate operation of the language model.',
    },
    {
      name: 'wrapStream',
      type: '({ doStream: DoStreamFunction, params: LanguageModelV1CallOptions, model: LanguageModelV1 }) => Promise<DoStreamResult>',
      description: 'Wraps the stream operation of the language model.',
    },
  ]}
/>
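
Per the signature above, `transformParams` also receives a `type` field distinguishing generate from stream calls; a minimal sketch:

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

export const yourTypeAwareMiddleware: LanguageModelV1Middleware = {
  transformParams: async ({ type, params }) => {
    // `type` is 'generate' or 'stream', matching the operation
    // that will be invoked on the wrapped model
    console.log(`transforming params for a ${type} call`);
    return params;
  },
};
```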
28 changes: 28 additions & 0 deletions examples/ai-core/src/middleware/add-to-last-user-message.ts
@@ -0,0 +1,28 @@
import { LanguageModelV1CallOptions } from 'ai';

export function addToLastUserMessage({
  text,
  params,
}: {
  text: string;
  params: LanguageModelV1CallOptions;
}): LanguageModelV1CallOptions {
  const { prompt, ...rest } = params;

  const lastMessage = prompt.at(-1);

  if (lastMessage?.role !== 'user') {
    return params;
  }

  return {
    ...rest,
    prompt: [
      ...prompt.slice(0, -1),
      {
        ...lastMessage,
        content: [{ type: 'text', text }, ...lastMessage.content],
      },
    ],
  };
}
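
The RAG example above also references a helper like `getLastUserMessageText`; the remaining changed files are not shown here, but a hedged sketch of such a helper might look like:

```ts
import { LanguageModelV1CallOptions } from 'ai';

export function getLastUserMessageText({
  prompt,
}: {
  prompt: LanguageModelV1CallOptions['prompt'];
}): string | undefined {
  const lastMessage = prompt.at(-1);

  if (lastMessage?.role !== 'user') {
    return undefined;
  }

  // concatenate the text parts of the last user message
  let text = '';
  for (const part of lastMessage.content) {
    if (part.type === 'text') {
      text += part.text;
    }
  }

  return text;
}
```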