Skip to content

Commit

Permalink
Merge pull request #20 from jxnl/test-modes
Browse files Browse the repository at this point in the history
Test Modes
  • Loading branch information
jxnl authored Jan 3, 2024
2 parents 7e63838 + 98851ba commit d2b773e
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/constants/modes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ export const MODE = {
JSON: "JSON",
MD_JSON: "MD_JSON",
JSON_SCHEMA: "JSON_SCHEMA"
}
} as const

export type MODE = keyof typeof MODE
20 changes: 16 additions & 4 deletions src/oai/params.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,24 @@ export function OAIBuildFunctionParams(definition, params) {
}

export function OAIBuildToolFunctionParams(definition, params) {
const { name, ...definitionParams } = definition

return {
...params,
tool_choice: {
type: "function",
function: { name: definition.name }
function: { name }
},
tools: [...(params?.tools ?? []), definition]
tools: [
{
type: "function",
function: {
name,
parameters: definitionParams
}
},
...(params?.tools ?? [])
]
}
}

Expand All @@ -27,7 +38,8 @@ export function OAIBuildMessageBasedParams(definition, params, mode) {
response_format: { type: "json_object" }
},
[MODE.JSON_SCHEMA]: {
response_format: { type: "json_object", schema: definition }
//TODO: not sure what is different about this mode - the OAI sdk doesnt accept a schema here
response_format: { type: "json_object" }
}
}

Expand All @@ -39,7 +51,7 @@ export function OAIBuildMessageBasedParams(definition, params, mode) {
messages: [
...(params?.messages ?? []),
{
role: "SYSTEM",
role: "system",
content: `
Given a user prompt, you will return fully valid JSON based on the following description and schema.
You will return no other prose. You will take into account the descriptions for each paramater within the schema
Expand Down
3 changes: 1 addition & 2 deletions src/oai/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ export function OAIResponseFnArgsParser(
| OpenAI.Chat.Completions.ChatCompletion
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data

const text = parsedData.choices?.[0]?.message?.function_call?.arguments ?? "{}"

return JSON.parse(text)
Expand All @@ -56,7 +55,7 @@ export function OAIResponseToolArgsParser(
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data

const text = parsedData.choices?.[0]?.message?.tool_call?.function?.arguments ?? "{}"
const text = parsedData.choices?.[0]?.message?.tool_calls?.[0]?.function?.arguments ?? "{}"

return JSON.parse(text)
}
Expand Down
62 changes: 62 additions & 0 deletions tests/mode.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import Instructor from "@/instructor"
import { describe, expect, test } from "bun:test"
import OpenAI from "openai"
import { z } from "zod"

import { MODE } from "@/constants/modes"

const models_latest = ["gpt-3.5-turbo-1106", "gpt-4-1106-preview"]
const models_old = ["gpt-3.5-turbo", "gpt-4"]

const createTestCases = (): { model: string; mode: MODE }[] => {
const { FUNCTIONS, ...rest } = MODE
const modes = Object.values(rest)

return [
...models_latest.flatMap(model => modes.map(mode => ({ model, mode }))),
...models_old.flatMap(model => ({ model, mode: FUNCTIONS }))
]
}

const UserSchema = z.object({
age: z.number(),
name: z.string().refine(name => name.includes(" "), {
message: "Name must contain a space"
})
})

type User = z.infer<typeof UserSchema>

async function extractUser(model: string, mode: MODE) {
const oai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY ?? undefined,
organization: process.env.OPENAI_ORG_ID ?? undefined
})

const client = Instructor({
client: oai,
mode: mode
})

const user: User = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: model,
response_model: UserSchema,
max_retries: 3
})

return user
}

describe("Modes", async () => {
const testCases = createTestCases()

for await (const { model, mode } of testCases) {
test(`Should return extracted name and age for model ${model} and mode ${mode}`, async () => {
const user = await extractUser(model, mode)

expect(user.name).toEqual("Jason Liu")
expect(user.age).toEqual(30)
})
}
})

0 comments on commit d2b773e

Please sign in to comment.