Skip to content

Commit

Permalink
Anthropic tools (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
roodboi authored Apr 7, 2024
1 parent 5d7489f commit 76ef059
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 80 deletions.
5 changes: 5 additions & 0 deletions .changeset/tasty-seas-shave.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@instructor-ai/instructor": minor
---

updated client types to be more flexible - added tests for latest anthropic updates and llm-polyglot major
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ The main class for creating an Instructor client.

**createInstructor**
```typescript
function createInstructor<C extends SupportedInstructorClient = OpenAI>(args: {
function createInstructor<C extends GenericClient | OpenAI>(args: {
client: OpenAILikeClient<C>;
mode: Mode;
debug?: boolean;
Expand All @@ -100,7 +100,13 @@ Returns the extended OpenAI-Like client.

**chat.completions.create**
```typescript
chat.completions.create<T extends z.AnyZodObject>(params: ChatCompletionCreateParamsWithModel<T>): Promise<z.infer<T> & { _meta?: CompletionMeta }>
chat.completions.create<
T extends z.AnyZodObject,
P extends T extends z.AnyZodObject ? ChatCompletionCreateParamsWithModel<T>
: ClientTypeChatCompletionParams<OpenAILikeClient<C>> & { response_model: never }
>(
params: P
): Promise<ReturnTypeBasedOnParams<typeof this.client, P>>
```
When `response_model` is present in the params, this creates a chat completion with structured extraction based on the provided schema — otherwise the call is proxied back to the provided client.

Expand Down
Binary file modified bun.lockb
Binary file not shown.
7 changes: 4 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@
"zod": ">=3.22.4"
},
"devDependencies": {
"@anthropic-ai/sdk": "latest",
"@changesets/changelog-github": "^0.5.0",
"@changesets/cli": "^2.27.1",
"@ianvs/prettier-plugin-sort-imports": "4.1.0",
"@types/bun": "^1.0.0",
"@types/node": "^20.10.6",
"@typescript-eslint/eslint-plugin": "^6.11.0",
"@typescript-eslint/parser": "^6.11.0",
"@typescript-eslint/parser": "^7.5.0",
"@typescript-eslint/eslint-plugin": "^7.5.0",
"eslint-config": "^0.3.0",
"eslint-config-prettier": "^9.0.0",
"eslint-config-turbo": "^1.10.12",
Expand All @@ -74,7 +75,7 @@
"eslint-plugin-only-warn": "^1.1.0",
"eslint-plugin-prettier": "^5.1.2",
"husky": "^8.0.3",
"llm-polyglot": "^0.0.3",
"llm-polyglot": "1.0.0",
"prettier": "latest",
"ts-inference-check": "^0.3.0",
"tsup": "^8.0.1",
Expand Down
10 changes: 7 additions & 3 deletions src/dsl/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ import { InstructorClient } from "@/instructor"
import OpenAI from "openai"
import { RefinementCtx, z } from "zod"

import { GenericClient } from ".."

type AsyncSuperRefineFunction = (data: string, ctx: RefinementCtx) => Promise<void>

export const LLMValidator = (
instructor: InstructorClient,
export const LLMValidator = <C extends GenericClient | OpenAI>(
instructor: InstructorClient<C>,
statement: string,
params: Omit<OpenAI.ChatCompletionCreateParams, "messages">
): AsyncSuperRefineFunction => {
Expand Down Expand Up @@ -42,7 +44,9 @@ export const LLMValidator = (
}
}

export const moderationValidator = (client: InstructorClient) => {
export const moderationValidator = <C extends GenericClient | OpenAI>(
client: InstructorClient<C>
) => {
return async (value: string, ctx: z.RefinementCtx) => {
try {
if (!(client instanceof OpenAI)) {
Expand Down
78 changes: 48 additions & 30 deletions src/instructor.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import {
ChatCompletionCreateParamsWithModel,
GenericChatCompletion,
GenericClient,
InstructorConfig,
LogLevel,
OpenAILikeClient,
ReturnTypeBasedOnParams,
SupportedInstructorClient
ReturnTypeBasedOnParams
} from "@/types"
import OpenAI from "openai"
import { z } from "zod"
Expand All @@ -24,7 +24,7 @@ import { ClientTypeChatCompletionParams, CompletionMeta } from "./types"

const MAX_RETRIES_DEFAULT = 0

class Instructor<C extends SupportedInstructorClient> {
class Instructor<C extends GenericClient | OpenAI> {
readonly client: OpenAILikeClient<C>
readonly mode: Mode
readonly provider: Provider
Expand All @@ -41,10 +41,12 @@ class Instructor<C extends SupportedInstructorClient> {
this.debug = debug

const provider =
this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.ANYSCALE) ? PROVIDERS.ANYSCALE
: this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.TOGETHER) ? PROVIDERS.TOGETHER
: this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.OAI) ? PROVIDERS.OAI
: this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.ANTHROPIC) ? PROVIDERS.ANTHROPIC
typeof this.client?.baseURL === "string" ?
this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.ANYSCALE) ? PROVIDERS.ANYSCALE
: this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.TOGETHER) ? PROVIDERS.TOGETHER
: this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.OAI) ? PROVIDERS.OAI
: this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.ANTHROPIC) ? PROVIDERS.ANTHROPIC
: PROVIDERS.OTHER
: PROVIDERS.OTHER

this.provider = provider
Expand Down Expand Up @@ -114,8 +116,8 @@ class Instructor<C extends SupportedInstructorClient> {
let completionParams = withResponseModel({
params: {
...params,
stream: false
},
stream: params.stream ?? false
} as OpenAI.ChatCompletionCreateParams,
mode: this.mode,
response_model
})
Expand All @@ -141,12 +143,18 @@ class Instructor<C extends SupportedInstructorClient> {
}
}

let completion: GenericChatCompletion | null = null
let completion

try {
completion = (await this.client.chat.completions.create(
resolvedParams
)) as GenericChatCompletion
if (this.client.chat?.completions?.create) {
const result = await this.client.chat.completions.create({
...resolvedParams,
stream: false
})
completion = result as GenericChatCompletion<typeof result>
} else {
throw new Error("Unsupported client type")
}
this.log("debug", "raw standard completion response: ", completion)
} catch (error) {
this.log(
Expand Down Expand Up @@ -245,7 +253,7 @@ class Instructor<C extends SupportedInstructorClient> {
params: {
...params,
stream: true
},
} as OpenAI.ChatCompletionCreateParams,
response_model,
mode: this.mode
})
Expand All @@ -260,13 +268,19 @@ class Instructor<C extends SupportedInstructorClient> {

return streamClient.create({
completionPromise: async () => {
const completion = await this.client.chat.completions.create(completionParams)
this.log("debug", "raw stream completion response: ", completion)

return OAIStream({
//TODO: we need to move away from strict openai types - need to cast here but should update to be more flexible
res: completion as AsyncIterable<OpenAI.ChatCompletionChunk>
})
if (this.client.chat?.completions?.create) {
const completion = await this.client.chat.completions.create({
...completionParams,
stream: true
})
this.log("debug", "raw stream completion response: ", completion)

return OAIStream({
res: completion as unknown as AsyncIterable<OpenAI.ChatCompletionChunk>
})
} else {
throw new Error("Unsupported client type")
}
},
response_model
})
Expand All @@ -289,7 +303,7 @@ class Instructor<C extends SupportedInstructorClient> {
create: async <
T extends z.AnyZodObject,
P extends T extends z.AnyZodObject ? ChatCompletionCreateParamsWithModel<T>
: ClientTypeChatCompletionParams<typeof this.client> & { response_model: never }
: ClientTypeChatCompletionParams<OpenAILikeClient<C>> & { response_model: never }
>(
params: P
): Promise<ReturnTypeBasedOnParams<typeof this.client, P>> => {
Expand All @@ -308,20 +322,23 @@ class Instructor<C extends SupportedInstructorClient> {
>
}
} else {
const result =
this.isStandardStream(params) ?
await this.client.chat.completions.create(params)
: await this.client.chat.completions.create(params)
if (this.client.chat?.completions?.create) {
const result =
this.isStandardStream(params) ?
await this.client.chat.completions.create(params)
: await this.client.chat.completions.create(params)

return result as ReturnTypeBasedOnParams<typeof this.client, P>
return result as unknown as ReturnTypeBasedOnParams<OpenAILikeClient<C>, P>
} else {
throw new Error("Completion method is undefined")
}
}
}
}
}
}

export type InstructorClient<C extends SupportedInstructorClient = OpenAI> = Instructor<C> &
OpenAILikeClient<C>
export type InstructorClient<C extends GenericClient | OpenAI> = Instructor<C> & OpenAILikeClient<C>

/**
* Creates an instance of the `Instructor` class.
Expand All @@ -344,7 +361,7 @@ export type InstructorClient<C extends SupportedInstructorClient = OpenAI> = Ins
* @param args
* @returns
*/
export default function <C extends SupportedInstructorClient = OpenAI>(args: {
export default function createInstructor<C extends GenericClient | OpenAI>(args: {
client: OpenAILikeClient<C>
mode: Mode
debug?: boolean
Expand All @@ -355,6 +372,7 @@ export default function <C extends SupportedInstructorClient = OpenAI>(args: {
if (prop in target) {
return Reflect.get(target, prop, receiver)
}

return Reflect.get(target.client, prop, receiver)
}
})
Expand Down
58 changes: 30 additions & 28 deletions src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,59 +7,62 @@ import {
type ResponseModel as ZResponseModel
} from "zod-stream"

export type GenericCreateParams = Omit<Partial<OpenAI.ChatCompletionCreateParams>, "model"> & {
export type GenericCreateParams<M = unknown> = Omit<
Partial<OpenAI.ChatCompletionCreateParams>,
"model" | "messages"
> & {
model: string
messages: OpenAI.ChatCompletionCreateParams["messages"]
messages: M[]
stream?: boolean
max_tokens?: number | null
[key: string]: unknown
}

export type GenericChatCompletion = Partial<OpenAI.Chat.Completions.ChatCompletion> & {
export type GenericChatCompletion<T = unknown> = Partial<OpenAI.Chat.Completions.ChatCompletion> & {
[key: string]: unknown
choices?: T
}

export type GenericChatCompletionStream = AsyncIterable<
Partial<OpenAI.Chat.Completions.ChatCompletionChunk> & { [key: string]: unknown }
export type GenericChatCompletionStream<T = unknown> = AsyncIterable<
Partial<OpenAI.Chat.Completions.ChatCompletionChunk> & {
[key: string]: unknown
choices?: T
}
>

export type GenericClient<
P extends GenericCreateParams = GenericCreateParams,
Completion = GenericChatCompletion,
Chunk = GenericChatCompletionStream
> = {
baseURL: string
chat: {
completions: {
create: (params: P) => CreateMethodReturnType<P, Completion, Chunk>
export type GenericClient = {
[key: string]: unknown
baseURL?: string
chat?: {
completions?: {
create?: (params: GenericCreateParams) => Promise<unknown>
}
}
}

export type CreateMethodReturnType<
P extends GenericCreateParams,
Completion = GenericChatCompletion,
Chunk = GenericChatCompletionStream
> = P extends { stream: true } ? Promise<Chunk> : Promise<Completion>

export type ClientTypeChatCompletionParams<C extends SupportedInstructorClient> =
export type ClientTypeChatCompletionParams<C> =
C extends OpenAI ? OpenAI.ChatCompletionCreateParams : GenericCreateParams

export type ClientType<C extends SupportedInstructorClient> =
C extends OpenAI ? "openai" : "generic"
export type ClientType<C> =
C extends OpenAI ? "openai"
: C extends GenericClient ? "generic"
: never

export type OpenAILikeClient<C extends SupportedInstructorClient> =
ClientType<C> extends "openai" ? OpenAI : C
export type OpenAILikeClient<C> = C extends OpenAI ? OpenAI : C & GenericClient

export type SupportedInstructorClient = GenericClient | OpenAI

export type LogLevel = "debug" | "info" | "warn" | "error"

export type CompletionMeta = Partial<ZCompletionMeta> & {
usage?: OpenAI.CompletionUsage
}

export type Mode = ZMode

export type ResponseModel<T extends z.AnyZodObject> = ZResponseModel<T>

export interface InstructorConfig<C extends SupportedInstructorClient> {
export interface InstructorConfig<C> {
client: OpenAILikeClient<C>
mode: Mode
debug?: boolean
Expand Down Expand Up @@ -87,5 +90,4 @@ export type ReturnTypeBasedOnParams<C, P> =
P extends { stream: true } ?
Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
: OpenAI.Chat.Completions.ChatCompletion
: P extends { stream: true } ? Promise<GenericChatCompletionStream>
: Promise<GenericChatCompletion>
: Promise<unknown>
Loading

0 comments on commit 76ef059

Please sign in to comment.