-
Notifications
You must be signed in to change notification settings - Fork 390
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This change adds a simple calculator example which uses JsonProgram for its schema. This is similar to the math example in the typescript implementation. There are some changes: - The JsonProgram validator and translator are included in the example folder. We'll evaluate whether to include them in the main typechat python implementation at a later date. - The translator takes an API that exposes the functions a model can call. This example shows two ways of providing the schema - either as a TypedDict with Callable values (schema.py) or as a Protocol with instance methods (schemaV2.py).
- Loading branch information
1 parent
65e3a66
commit 7fa71bc
Showing
5 changed files
with
423 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import asyncio | ||
import json | ||
import sys | ||
from dotenv import dotenv_values | ||
import schema as math | ||
from typechat import Failure, create_language_model | ||
from program import TypeChatProgramTranslator, TypeChatProgramValidator, JsonProgram, evaluate_json_program | ||
|
||
|
||
async def handle_call(func: str, args: list[int | float]) -> int | float: | ||
print(f"{func}({json.dumps(args)}) ") | ||
match func: | ||
case "add": | ||
return args[0] + args[1] | ||
case "sub": | ||
return args[0] - args[1] | ||
case "mul": | ||
return args[0] * args[1] | ||
case "div": | ||
return args[0] / args[1] | ||
case "neg": | ||
return -1 * args[0] | ||
case "id": | ||
return args[0] | ||
case _: | ||
raise ValueError(f'Unexpected function name {func}') | ||
|
||
|
||
async def main(): | ||
vals = dotenv_values() | ||
model = create_language_model(vals) | ||
validator = TypeChatProgramValidator(JsonProgram) | ||
translator = TypeChatProgramTranslator(model, validator, math.MathAPI) | ||
print("🧮> ", end="", flush=True) | ||
for line in sys.stdin: | ||
result = await translator.translate(line) | ||
if isinstance(result, Failure): | ||
print("Translation Failed ❌") | ||
print(f"Context: {result.message}") | ||
else: | ||
result = result.value | ||
print("Translation Succeeded! ✅\n") | ||
print("JSON View") | ||
print(json.dumps(result, indent=2)) | ||
math_result = await evaluate_json_program(result, handle_call) # type: ignore | ||
print(f"Math Result: {math_result}") | ||
|
||
print("\n🧮> ", end="", flush=True) | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
from __future__ import annotations | ||
import asyncio | ||
import json | ||
from textwrap import dedent, indent | ||
from typing import TypeVar, Any, Callable, Awaitable, TypedDict, Annotated, NotRequired, override, Sequence | ||
|
||
from typechat import ( | ||
Failure, | ||
Result, | ||
Success, | ||
TypeChatModel, | ||
TypeChatValidator, | ||
TypeChatTranslator, | ||
python_type_to_typescript_schema, | ||
) | ||
import collections.abc | ||
|
||
T = TypeVar("T", covariant=True) | ||
|
||
def Doc(s: str) -> str: | ||
return s | ||
|
||
|
||
program_schema_text = ''' | ||
// A program consists of a sequence of function calls that are evaluated in order. | ||
export type Program = { | ||
"@steps": FunctionCall[]; | ||
} | ||
// A function call specifies a function name and a list of argument expressions. Arguments may contain | ||
// nested function calls and result references. | ||
export type FunctionCall = { | ||
// Name of the function | ||
"@func": string; | ||
// Arguments for the function, if any | ||
"@args"?: Expression[]; | ||
}; | ||
// An expression is a JSON value, a function call, or a reference to the result of a preceding expression. | ||
export type Expression = JsonValue | FunctionCall | ResultReference; | ||
// A JSON value is a string, a number, a boolean, null, an object, or an array. Function calls and result | ||
// references can be nested in objects and arrays. | ||
export type JsonValue = string | number | boolean | null | { [x: string]: Expression } | Expression[]; | ||
// A result reference represents the value of an expression from a preceding step. | ||
export type ResultReference = { | ||
// Index of the previous expression in the "@steps" array | ||
"@ref": number; | ||
}; | ||
''' | ||
|
||
|
||
ResultReference = TypedDict( | ||
"ResultReference", {"@ref": Annotated[int, Doc("Index of the previous expression in the 'steps' array")]} | ||
) | ||
|
||
FunctionCall = TypedDict( | ||
"FunctionCall", | ||
{ | ||
"@func": Annotated[str, Doc("Name of the function")], | ||
"@args": NotRequired[Annotated[list["Expression"], Doc("Arguments for the function, if any")]], | ||
}, | ||
) | ||
|
||
JsonValue = str | int | float | bool | None | dict[str, "Expression"] | list["Expression"] | ||
Expression = JsonValue | FunctionCall | ResultReference | ||
|
||
JsonProgram = TypedDict("Program", {"@steps": list[FunctionCall]}) | ||
|
||
|
||
async def evaluate_json_program(program: JsonProgram, onCall: Callable[[str, Sequence[Expression]], Awaitable[Expression]]) -> Expression | Sequence[Expression]: | ||
results: list[Expression] | Expression = [] | ||
|
||
async def evaluate_array(array: Sequence[Expression]) -> Sequence[Expression]: | ||
return await asyncio.gather(*[evaluate_call(e) for e in array]) # type: ignore | ||
|
||
async def evaluate_object(expr: FunctionCall): | ||
if "@ref" in expr: | ||
index = expr["@ref"] | ||
if index < len(results): | ||
return results[index] | ||
|
||
elif "@func" in expr and "@args" in expr: | ||
function_name = expr["@func"] | ||
return await onCall(function_name, await evaluate_array(expr["@args"])) | ||
|
||
elif isinstance(expr, collections.abc.Sequence): | ||
return await evaluate_array(expr) | ||
|
||
else: | ||
raise ValueError("This condition should never hit") | ||
|
||
async def evaluate_call(expr: FunctionCall) -> Expression | Sequence[Expression]: | ||
if isinstance(expr, int) or isinstance(expr, float) or isinstance(expr, str): | ||
return expr | ||
return await evaluate_object(expr) | ||
|
||
for step in program["@steps"]: | ||
results.append(await evaluate_call(step)) # type: ignore | ||
|
||
if len(results) > 0: | ||
return results[-1] | ||
else: | ||
return None | ||
|
||
|
||
class TypeChatProgramValidator(TypeChatValidator[T]): | ||
def __init__(self, py_type: type[T]): | ||
# the base class init method creates a typeAdapter for T. This operation fails for the JsonProgram type | ||
super().__init__(py_type=Any) | ||
|
||
@override | ||
def validate(self, json_text: str) -> Result[T]: | ||
# Pydantic is not able to validate JsonProgram instances. It fails with a recursion error. | ||
# For JsonProgram, simply validate that it has a non-zero number of @steps | ||
# TODO: extend validations | ||
typed_dict = json.loads(json_text) | ||
if "@steps" in typed_dict and isinstance(typed_dict["@steps"], collections.abc.Sequence): | ||
return Success(typed_dict) | ||
else: | ||
return Failure("This is not a valid program. The program must have an array of @steps") | ||
|
||
|
||
class TypeChatProgramTranslator(TypeChatTranslator[T]): | ||
_api_declaration_str: str | ||
|
||
def __init__(self, model: TypeChatModel, validator: TypeChatProgramValidator[T], api_type: type): | ||
super().__init__(model=model, validator=validator, target_type=Any) | ||
conversion_result = python_type_to_typescript_schema(api_type) | ||
self._api_declaration_str = conversion_result.typescript_schema_str | ||
|
||
@override | ||
def _create_request_prompt(self, intent: str) -> str: | ||
api_decl_str = indent(self._api_declaration_str, " ") | ||
|
||
prompt = F""" | ||
You are a service that translates user requests into programs represented as JSON using the following TypeScript definitions: | ||
``` | ||
{program_schema_text} | ||
``` | ||
The programs can call functions from the API defined in the following TypeScript definitions: | ||
``` | ||
{api_decl_str} | ||
``` | ||
The following is a user request: | ||
''' | ||
{intent} | ||
''' | ||
The following is the user request translated into a JSON program object with 2 spaces of indentation and no properties with the value undefined: | ||
""" | ||
prompt = dedent(prompt) | ||
return prompt | ||
|
||
@override | ||
def _create_repair_prompt(self, validation_error: str) -> str: | ||
validation_error = indent(validation_error, " ") | ||
prompt = F""" | ||
The JSON program object is invalid for the following reason: | ||
''' | ||
{validation_error} | ||
''' | ||
The following is a revised JSON program object: | ||
""" | ||
return dedent(prompt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import TypedDict, Annotated, Callable | ||
|
||
|
||
def Doc(s: str) -> str: | ||
return s | ||
|
||
|
||
class MathAPI(TypedDict): | ||
""" | ||
This is API for a simple calculator | ||
""" | ||
|
||
add: Annotated[Callable[[float, float], float], Doc("Add two numbers")] | ||
sub: Annotated[Callable[[float, float], float], Doc("Subtract two numbers")] | ||
mul: Annotated[Callable[[float, float], float], Doc("Multiply two numbers")] | ||
div: Annotated[Callable[[float, float], float], Doc("Divide two numbers")] | ||
neg: Annotated[Callable[[float], float], Doc("Negate a number")] | ||
id: Annotated[Callable[[float], float], Doc("Identity function")] | ||
unknown: Annotated[Callable[[str], float], Doc("Unknown request")] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import Protocol, runtime_checkable | ||
|
||
|
||
@runtime_checkable | ||
class MathAPI(Protocol): | ||
""" | ||
This is API for a simple calculator | ||
""" | ||
|
||
def add(self, x: float, y: float) -> float: | ||
""" | ||
Add two numbers | ||
""" | ||
... | ||
|
||
def sub(self, x: float, y: float) -> float: | ||
""" | ||
Subtract two numbers | ||
""" | ||
... | ||
|
||
def mul(self, x: float, y: float) -> float: | ||
""" | ||
Multiply two numbers | ||
""" | ||
... | ||
|
||
def div(self, x: float, y: float) -> float: | ||
""" | ||
Divide two numbers | ||
""" | ||
... | ||
|
||
def neg(self, x: float) -> float: | ||
""" | ||
Negate a number | ||
""" | ||
... | ||
|
||
def id(self, x: float, y: float) -> float: | ||
""" | ||
Identity function | ||
""" | ||
... | ||
|
||
def unknown(self, text: str) -> float: | ||
""" | ||
unknown request | ||
""" | ||
... |
Oops, something went wrong.