Skip to content

Commit

Permalink
remove unused and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
roodboi committed Jan 2, 2024
1 parent dd32428 commit ab3c0cc
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 67 deletions.
92 changes: 90 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,91 @@
node_modules/
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.

# dependencies
node_modules
/.pnp
.pnp.js

# testing
/coverage

# next.js
.next/
out/

# production
/build

# misc
.DS_Store
*.pem

# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*

# local env files
.env
dist/
.env.local
.env.development.local
.env.test.local
.env.production.local

# vercel
.vercel
.turbo

node_modules
.pnp
.pnp.js

# testing
coverage

# next.js
.next/
out/
build

# misc
.DS_Store
*.pem

# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*

# local env files
.env
.env.local
.env.development.local
.env.test.local
.env.production.local

# turbo
.turbo

# vercel
.vercel

dist

.pnp.*
.yarn/*
!.yarn/patches
!.yarn/plugins
!.yarn/releases
!.yarn/sdks
!.yarn/versions

# vim
*.sw*

# env
.env*.local
.envrc


tsconfig.tsbuildinfo

Binary file modified bun.lockb
Binary file not shown.
2 changes: 1 addition & 1 deletion examples/test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ const oai = new OpenAI({

const client = new Instruct({
client: oai,
mode: ""
mode: "FUNCTIONS"
})

const user: User = await client.chat.completions.create({
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"@changesets/changelog-github": "^0.5.0",
"@changesets/cli": "^2.27.1",
"@ianvs/prettier-plugin-sort-imports": "4.1.0",
"@types/bun": "^1.0.0",
"@types/node": "^20.10.6",
"@typescript-eslint/eslint-plugin": "6.5.0",
"@typescript-eslint/parser": "6.5.0",
Expand Down
53 changes: 12 additions & 41 deletions src/instructor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import {
} from "./oai/params"
import {
OAIResponseFnArgsParser,
OAIResponseTextParser,
OAIResponseJSONStringParser,
OAIResponseToolArgsParser
} from "./oai/parser"

const MODE_TO_PARSER = {
[MODE.FUNCTIONS]: OAIResponseFnArgsParser,
[MODE.TOOLS]: OAIResponseToolArgsParser,
[MODE.JSON]: OAIResponseTextParser,
[MODE.MD_JSON]: OAIResponseTextParser,
[MODE.JSON_SCHEMA]: OAIResponseTextParser
[MODE.JSON]: OAIResponseJSONStringParser,
[MODE.MD_JSON]: OAIResponseJSONStringParser,
[MODE.JSON_SCHEMA]: OAIResponseJSONStringParser
}

const MODE_TO_PARAMS = {
Expand Down Expand Up @@ -58,24 +58,16 @@ export class Instruct {
* @returns {Promise<any>} The response from the chat completion.
*/
private chatCompletion = async ({
response_model,
max_retries = 3,
...params
}: PatchedChatCompletionCreateParams) => {
let attempts = 0

const functionConfig = this.generateSchemaFunction({
schema: response_model
})
const completionParams = this.buildChatCompletionParams(params)

const makeCompletionCall = async () => {
try {
const completion = await this.client.chat.completions.create({
stream: false,
...params,
...functionConfig
})

const completion = await this.client.chat.completions.create(completionParams)
const response = this.parseOAIResponse(completion)

return response
Expand All @@ -100,10 +92,15 @@ export class Instruct {
return await makeCompletionCallWithRetries()
}

/**
* Builds the chat completion parameters.
* @param {PatchedChatCompletionCreateParams} params - The parameters for chat completion.
* @returns {ChatCompletionCreateParamsNonStreaming} The chat completion parameters.
*/
private buildChatCompletionParams = ({
response_model,
...params
}: PatchedChatCompletionCreateParams) => {
}: PatchedChatCompletionCreateParams): ChatCompletionCreateParamsNonStreaming => {
const { definition } = createSchemaFunction({ schema: response_model })

const paramsForMode = MODE_TO_PARAMS[this.mode](definition, params)
Expand All @@ -125,32 +122,6 @@ export class Instruct {
return parser(response)
}

/**
* Generates a schema function.
* @param {ZodSchema<unknown>} schema - The schema to generate the function from.
* @returns {Object} The generated function configuration.
*/
private generateSchemaFunction({ schema }) {
const { definition } = createSchemaFunction({ schema })

return {
function_call: {
name: definition.name
},
functions: [
{
name: definition.name,
description: definition.description,
parameters: {
type: "object",
properties: definition.parameters,
required: definition.required
}
}
]
}
}

/**
* Public chat interface.
*/
Expand Down
9 changes: 2 additions & 7 deletions src/oai/fns/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { z } from "zod"
import { zodToJsonSchema } from "zod-to-json-schema"

export type FunctionPayload = {
name: string
Expand All @@ -9,7 +8,6 @@ export type FunctionPayload = {
}

export type FunctionDefinitionParams<P extends z.ZodType<unknown>, R extends z.ZodType<unknown>> = {
paramsSchema: P
jsonSchema?: object
name: string
description: string
Expand Down Expand Up @@ -42,15 +40,13 @@ export type FunctionDefinitionInterface = {
* };
*
* const functionDefinition = createFunctionDefinition({
* paramsSchema,
* name: 'greet',
* description: 'Greets a person.',
* execute
* });
*
*/
function createFunctionDefinition<P extends z.ZodType<unknown>, R extends z.ZodType<unknown>>({
paramsSchema,
jsonSchema,
name,
description,
Expand All @@ -59,8 +55,7 @@ function createFunctionDefinition<P extends z.ZodType<unknown>, R extends z.ZodT
}: FunctionDefinitionParams<P, R>): FunctionDefinitionInterface {
const run = async (params: unknown): Promise<unknown> => {
try {
const validatedParams = paramsSchema.parse(params)
return await execute(validatedParams)
return await execute(params)
} catch (error) {
console.error(`Error executing function ${name}:`, error)
throw error
Expand All @@ -72,7 +67,7 @@ function createFunctionDefinition<P extends z.ZodType<unknown>, R extends z.ZodT
definition: {
name: name,
description: description,
parameters: jsonSchema ?? zodToJsonSchema(paramsSchema),
parameters: jsonSchema,
required
}
}
Expand Down
1 change: 0 additions & 1 deletion src/oai/fns/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ export function createSchemaFunction({
}

return createFunctionDefinition({
paramsSchema: schema,
jsonSchema: propertiesMapping,
name,
description,
Expand Down
23 changes: 8 additions & 15 deletions src/oai/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ export function OAIResponseTextParser(
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data

const text =
parsedData.choices?.[0].delta?.content ?? parsedData?.choices[0]?.message?.content ?? ""
const text = parsedData?.choices[0]?.message?.content ?? "{}"

return text
return JSON.parse(text)
}

/**
Expand All @@ -37,12 +36,9 @@ export function OAIResponseFnArgsParser(
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data

const text =
parsedData.choices?.[0]?.delta?.function_call?.arguments ??
parsedData.choices?.[0]?.message?.function_call?.arguments ??
null
const text = parsedData.choices?.[0]?.message?.function_call?.arguments ?? "{}"

return text
return JSON.parse(text)
}

/**
Expand All @@ -60,12 +56,9 @@ export function OAIResponseToolArgsParser(
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data

const text =
parsedData.choices?.[0]?.delta?.tool_call?.function?.arguments ??
parsedData.choices?.[0]?.message?.tool_call?.function?.arguments ??
null
const text = parsedData.choices?.[0]?.message?.tool_call?.function?.arguments ?? "{}"

return text
return JSON.parse(text)
}

/**
Expand All @@ -82,7 +75,7 @@ export function OAIResponseJSONStringParser(
| OpenAI.Chat.Completions.ChatCompletion
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data
const text = parsedData.choices?.[0]?.delta?.content ?? parsedData?.choices[0]?.message?.content
const text = parsedData?.choices[0]?.message?.content ?? "{}"

return text
return JSON.parse(text)
}
43 changes: 43 additions & 0 deletions tests/functions.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { Instruct } from "@/instructor"
import { describe, expect, test } from "bun:test"
import OpenAI from "openai"
import { z } from "zod"

async function extractUser() {
const UserSchema = z.object({
age: z.number(),
name: z.string().refine(name => name.includes(" "), {
message: "Name must contain a space"
})
})

type User = z.infer<typeof UserSchema>

const oai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY ?? undefined,
organization: process.env.OPENAI_ORG_ID ?? undefined
})

const client = new Instruct({
client: oai,
mode: "FUNCTIONS"
})

const user: User = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-3.5-turbo",
response_model: UserSchema,
max_retries: 3
})

return user
}

describe("FunctionCall", () => {
test("Should return extracted name and age based on schema", async () => {
const user = await extractUser()

expect(user.name).toEqual("Jason Liu")
expect(user.age).toEqual(30)
})
})

0 comments on commit ab3c0cc

Please sign in to comment.