Skip to content

Commit

Permalink
Update provider API with provider settings (#1139)
Browse files Browse the repository at this point in the history
  • Loading branch information
diksipav authored Dec 10, 2024
1 parent 70b7195 commit efbbcfa
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 45 deletions.
7 changes: 4 additions & 3 deletions packages/ai/src/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
type QueryContext,
type StreamingMessage,
type RagRequest,
type EmbeddingRequest,
isPromptRequest,
} from "./types.js";
import { getHTTPSCRAMAuth } from "edgedb/dist/httpScram.js";
Expand Down Expand Up @@ -209,7 +210,7 @@ export class EdgeDBAI {
};
}

async generateEmbeddings(inputs: string[], model: string): Promise<number[]> {
async generateEmbeddings(request: EmbeddingRequest): Promise<number[]> {
const response = await (
await this.authenticatedFetch
)("embeddings", {
Expand All @@ -218,8 +219,8 @@ export class EdgeDBAI {
"Content-Type": "application/json",
},
body: JSON.stringify({
model,
input: inputs,
...request,
input: request.inputs,
}),
});

Expand Down
20 changes: 12 additions & 8 deletions packages/ai/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,15 @@ export interface QueryContext {
max_object_count?: number;
}

interface RagRequestBase {
stream?: boolean;
export interface RagRequestPrompt {
prompt: string;
[key: string]: unknown;
}

export type RagRequestPrompt = RagRequestBase & {
prompt: string;
};

export type RagRequestMessages = RagRequestBase & {
export interface RagRequestMessages {
messages: EdgeDBMessage[];
};
[key: string]: unknown;
}

export type RagRequest = RagRequestPrompt | RagRequestMessages;

Expand Down Expand Up @@ -153,3 +150,10 @@ export type StreamingMessage =
| MessageDelta
| MessageStop
| MessageError;

export interface EmbeddingRequest {
inputs: string[];
model: string;
dimensions?: number;
user?: string;
}
30 changes: 14 additions & 16 deletions packages/vercel-ai-provider/src/edgedb-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ import type {
} from "@ai-sdk/provider";
import {
type ParseResult,
type FetchFunction,
createEventSourceResponseHandler,
createJsonResponseHandler,
postJsonToApi,
generateId,
combineHeaders,
} from "@ai-sdk/provider-utils";
import {
type EdgeDBChatConfig,
type EdgeDBChatModelId,
type EdgeDBChatSettings,
type EdgeDBMessage,
Expand All @@ -30,11 +31,14 @@ import {
import { convertToEdgeDBMessages } from "./convert-to-edgedb-messages";
import { prepareTools } from "./edgedb-prepare-tools";

export interface EdgeDBLanguageModel extends LanguageModelV1 {
withSettings(settings: Partial<EdgeDBChatSettings>): EdgeDBChatLanguageModel;
export interface EdgeDBChatConfig {
provider: string;
fetch: FetchFunction;
// baseURL: string | null;
headers: () => Record<string, string | undefined>;
}

export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel {
export class EdgeDBChatLanguageModel implements LanguageModelV1 {
readonly specificationVersion = "v1";
readonly defaultObjectGenerationMode = "json";
readonly supportsImageUrls = false;
Expand All @@ -58,14 +62,6 @@ export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel {
return this.config.provider;
}

withSettings(settings: Partial<EdgeDBChatSettings>) {
return new EdgeDBChatLanguageModel(
this.modelId,
{ ...this.settings, ...settings },
this.config,
);
}

private getArgs({
// it's not really deprecated since the v2 is not out yet that accepts toolChoice, and tools at the top level
mode,
Expand Down Expand Up @@ -217,8 +213,9 @@ export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel {
const { messages } = args;

const { responseHeaders, value: response } = await postJsonToApi({
url: `rag`,
headers: options.headers,
// url: this.config.baseURL ? `${this.config.baseURL}/rag` : "rag",
url: "rag",
headers: combineHeaders(this.config.headers(), options.headers),
body: {
...args,
context: this.settings.context,
Expand Down Expand Up @@ -266,8 +263,9 @@ export class EdgeDBChatLanguageModel implements EdgeDBLanguageModel {
const { messages } = args;

const { responseHeaders, value: response } = await postJsonToApi({
url: `rag`,
headers: options.headers,
// url: this.config.baseURL ? `${this.config.baseURL}/rag` : "rag",
url: "rag",
headers: combineHeaders(this.config.headers(), options.headers),
body: {
...args,
context: this.settings.context,
Expand Down
7 changes: 0 additions & 7 deletions packages/vercel-ai-provider/src/edgedb-chat-settings.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import type { FetchFunction } from "@ai-sdk/provider-utils";

export type OpenAIModelId =
| "gpt-4o"
| "gpt-4o-mini"
Expand Down Expand Up @@ -67,11 +65,6 @@ export interface QueryContext {
max_object_count?: number;
}

export interface EdgeDBChatConfig {
provider: string;
fetch: FetchFunction;
}

export interface EdgeDBChatSettings {
context?: QueryContext;
prompt?: Prompt;
Expand Down
11 changes: 9 additions & 2 deletions packages/vercel-ai-provider/src/edgedb-embedding-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
createJsonResponseHandler,
type FetchFunction,
postJsonToApi,
combineHeaders,
} from "@ai-sdk/provider-utils";
import { z } from "zod";
import {
Expand All @@ -18,6 +19,8 @@ import { edgedbFailedResponseHandler } from "./edgedb-error";
interface EdgeDBEmbeddingConfig {
provider: string;
fetch?: FetchFunction;
// baseURL: string | null;
headers: () => Record<string, string | undefined>;
}

export class EdgeDBEmbeddingModel implements EmbeddingModelV1<string> {
Expand Down Expand Up @@ -71,12 +74,16 @@ export class EdgeDBEmbeddingModel implements EmbeddingModelV1<string> {
}

const { responseHeaders, value: response } = await postJsonToApi({
url: `embeddings`,
headers,
// url: this.config.baseURL
// ? `${this.config.baseURL}/embeddings`
// : "embeddings",
url: "embeddings",
headers: combineHeaders(this.config.headers(), headers),
body: {
model: this.modelId,
input: values,
encoding_format: "float",
// OpenAI props
dimensions: this.settings.dimensions,
user: this.settings.user,
},
Expand Down
44 changes: 35 additions & 9 deletions packages/vercel-ai-provider/src/edgedb-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ import type {
LanguageModelV1,
ProviderV1,
} from "@ai-sdk/provider";
import {
EdgeDBChatLanguageModel,
type EdgeDBLanguageModel,
} from "./edgedb-chat-language-model";
import { EdgeDBChatLanguageModel } from "./edgedb-chat-language-model";
import type {
EdgeDBChatModelId,
EdgeDBChatSettings,
Expand All @@ -24,21 +21,45 @@ import type {
const httpSCRAMAuth = getHTTPSCRAMAuth(cryptoUtils);

export interface EdgeDBProvider extends ProviderV1 {
(modelId: EdgeDBChatModelId | EdgeDBEmbeddingModelId): LanguageModelV1;
(
modelId: EdgeDBChatModelId | EdgeDBEmbeddingModelId,
settings?: EdgeDBChatSettings,
): LanguageModelV1;

languageModel(
modelId: EdgeDBChatModelId,
settings?: EdgeDBChatSettings,
): EdgeDBLanguageModel;
): EdgeDBChatLanguageModel;

textEmbeddingModel: (
modelId: EdgeDBEmbeddingModelId,
settings?: EdgeDBEmbeddingSettings,
) => EmbeddingModelV1<string>;
}

export async function createEdgeDB(client: Client): Promise<EdgeDBProvider> {
export interface EdgeDBProviderSettings {
/**
Use a different URL prefix for API calls, e.g. to use proxy servers.
*/
// baseURL?: string;

/**
Custom headers to include in the requests.
*/
headers?: Record<string, string>;
}

export async function createEdgeDB(
client: Client,
options: EdgeDBProviderSettings = {},
): Promise<EdgeDBProvider> {
const connectConfig = await client.resolveConnectionParams();
// const baseURL = withoutTrailingSlash(options.baseURL) ?? null;

// In case we want to add more things to this in the future
const getHeaders = () => ({
...options.headers,
});

const fetch = await getAuthenticatedFetch(
connectConfig,
Expand All @@ -53,6 +74,7 @@ export async function createEdgeDB(client: Client): Promise<EdgeDBProvider> {
new EdgeDBChatLanguageModel(modelId, settings, {
provider: "edgedb.chat",
fetch,
headers: getHeaders,
});

const createEmbeddingModel = (
Expand All @@ -62,17 +84,21 @@ export async function createEdgeDB(client: Client): Promise<EdgeDBProvider> {
return new EdgeDBEmbeddingModel(modelId, settings, {
provider: "edgedb.embedding",
fetch,
headers: getHeaders,
});
};

const provider = function (modelId: EdgeDBChatModelId) {
const provider = function (
modelId: EdgeDBChatModelId,
settings?: EdgeDBChatSettings,
) {
if (new.target) {
throw new Error(
"The EdgeDB model function cannot be called with the new keyword.",
);
}

return createChatModel(modelId);
return createChatModel(modelId, settings);
};

provider.languageModel = createChatModel;
Expand Down

0 comments on commit efbbcfa

Please sign in to comment.