Skip to content

Commit

Permalink
Python: Add math example (#156)
Browse files Browse the repository at this point in the history
This change adds a simple calculator example which uses JsonProgram for its schema. This is similar to the math example in the typescript implementation. 

There are some changes:
- The JsonProgram validator and translator are included in the example folder. We'll evaluate whether to include them in the main typechat python implementation at a later date.
- The translator takes an API that exposes the functions a model can call. This example shows two ways of providing the schema - either as a TypedDict with Callable values (schema.py) or as a Protocol with instance methods (schemaV2.py).
  • Loading branch information
hillary-mutisya authored Jan 4, 2024
1 parent 65e3a66 commit 7fa71bc
Show file tree
Hide file tree
Showing 5 changed files with 423 additions and 0 deletions.
52 changes: 52 additions & 0 deletions python/examples/math/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import asyncio
import json
import sys
from dotenv import dotenv_values
import schema as math
from typechat import Failure, create_language_model
from program import TypeChatProgramTranslator, TypeChatProgramValidator, JsonProgram, evaluate_json_program


async def handle_call(func: str, args: list[int | float]) -> int | float:
print(f"{func}({json.dumps(args)}) ")
match func:
case "add":
return args[0] + args[1]
case "sub":
return args[0] - args[1]
case "mul":
return args[0] * args[1]
case "div":
return args[0] / args[1]
case "neg":
return -1 * args[0]
case "id":
return args[0]
case _:
raise ValueError(f'Unexpected function name {func}')


async def main():
vals = dotenv_values()
model = create_language_model(vals)
validator = TypeChatProgramValidator(JsonProgram)
translator = TypeChatProgramTranslator(model, validator, math.MathAPI)
print("🧮> ", end="", flush=True)
for line in sys.stdin:
result = await translator.translate(line)
if isinstance(result, Failure):
print("Translation Failed ❌")
print(f"Context: {result.message}")
else:
result = result.value
print("Translation Succeeded! ✅\n")
print("JSON View")
print(json.dumps(result, indent=2))
math_result = await evaluate_json_program(result, handle_call) # type: ignore
print(f"Math Result: {math_result}")

print("\n🧮> ", end="", flush=True)


if __name__ == "__main__":
asyncio.run(main())
165 changes: 165 additions & 0 deletions python/examples/math/program.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from __future__ import annotations
import asyncio
import json
from textwrap import dedent, indent
from typing import TypeVar, Any, Callable, Awaitable, TypedDict, Annotated, NotRequired, override, Sequence

from typechat import (
Failure,
Result,
Success,
TypeChatModel,
TypeChatValidator,
TypeChatTranslator,
python_type_to_typescript_schema,
)
import collections.abc

T = TypeVar("T", covariant=True)

def Doc(s: str) -> str:
return s


program_schema_text = '''
// A program consists of a sequence of function calls that are evaluated in order.
export type Program = {
"@steps": FunctionCall[];
}
// A function call specifies a function name and a list of argument expressions. Arguments may contain
// nested function calls and result references.
export type FunctionCall = {
// Name of the function
"@func": string;
// Arguments for the function, if any
"@args"?: Expression[];
};
// An expression is a JSON value, a function call, or a reference to the result of a preceding expression.
export type Expression = JsonValue | FunctionCall | ResultReference;
// A JSON value is a string, a number, a boolean, null, an object, or an array. Function calls and result
// references can be nested in objects and arrays.
export type JsonValue = string | number | boolean | null | { [x: string]: Expression } | Expression[];
// A result reference represents the value of an expression from a preceding step.
export type ResultReference = {
// Index of the previous expression in the "@steps" array
"@ref": number;
};
'''


ResultReference = TypedDict(
"ResultReference", {"@ref": Annotated[int, Doc("Index of the previous expression in the 'steps' array")]}
)

FunctionCall = TypedDict(
"FunctionCall",
{
"@func": Annotated[str, Doc("Name of the function")],
"@args": NotRequired[Annotated[list["Expression"], Doc("Arguments for the function, if any")]],
},
)

JsonValue = str | int | float | bool | None | dict[str, "Expression"] | list["Expression"]
Expression = JsonValue | FunctionCall | ResultReference

JsonProgram = TypedDict("Program", {"@steps": list[FunctionCall]})


async def evaluate_json_program(program: JsonProgram, onCall: Callable[[str, Sequence[Expression]], Awaitable[Expression]]) -> Expression | Sequence[Expression]:
results: list[Expression] | Expression = []

async def evaluate_array(array: Sequence[Expression]) -> Sequence[Expression]:
return await asyncio.gather(*[evaluate_call(e) for e in array]) # type: ignore

async def evaluate_object(expr: FunctionCall):
if "@ref" in expr:
index = expr["@ref"]
if index < len(results):
return results[index]

elif "@func" in expr and "@args" in expr:
function_name = expr["@func"]
return await onCall(function_name, await evaluate_array(expr["@args"]))

elif isinstance(expr, collections.abc.Sequence):
return await evaluate_array(expr)

else:
raise ValueError("This condition should never hit")

async def evaluate_call(expr: FunctionCall) -> Expression | Sequence[Expression]:
if isinstance(expr, int) or isinstance(expr, float) or isinstance(expr, str):
return expr
return await evaluate_object(expr)

for step in program["@steps"]:
results.append(await evaluate_call(step)) # type: ignore

if len(results) > 0:
return results[-1]
else:
return None


class TypeChatProgramValidator(TypeChatValidator[T]):
def __init__(self, py_type: type[T]):
# the base class init method creates a typeAdapter for T. This operation fails for the JsonProgram type
super().__init__(py_type=Any)

@override
def validate(self, json_text: str) -> Result[T]:
# Pydantic is not able to validate JsonProgram instances. It fails with a recursion error.
# For JsonProgram, simply validate that it has a non-zero number of @steps
# TODO: extend validations
typed_dict = json.loads(json_text)
if "@steps" in typed_dict and isinstance(typed_dict["@steps"], collections.abc.Sequence):
return Success(typed_dict)
else:
return Failure("This is not a valid program. The program must have an array of @steps")


class TypeChatProgramTranslator(TypeChatTranslator[T]):
_api_declaration_str: str

def __init__(self, model: TypeChatModel, validator: TypeChatProgramValidator[T], api_type: type):
super().__init__(model=model, validator=validator, target_type=Any)
conversion_result = python_type_to_typescript_schema(api_type)
self._api_declaration_str = conversion_result.typescript_schema_str

@override
def _create_request_prompt(self, intent: str) -> str:
api_decl_str = indent(self._api_declaration_str, " ")

prompt = F"""
You are a service that translates user requests into programs represented as JSON using the following TypeScript definitions:
```
{program_schema_text}
```
The programs can call functions from the API defined in the following TypeScript definitions:
```
{api_decl_str}
```
The following is a user request:
'''
{intent}
'''
The following is the user request translated into a JSON program object with 2 spaces of indentation and no properties with the value undefined:
"""
prompt = dedent(prompt)
return prompt

@override
def _create_repair_prompt(self, validation_error: str) -> str:
validation_error = indent(validation_error, " ")
prompt = F"""
The JSON program object is invalid for the following reason:
'''
{validation_error}
'''
The following is a revised JSON program object:
"""
return dedent(prompt)
19 changes: 19 additions & 0 deletions python/examples/math/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import TypedDict, Annotated, Callable


def Doc(s: str) -> str:
return s


class MathAPI(TypedDict):
"""
This is API for a simple calculator
"""

add: Annotated[Callable[[float, float], float], Doc("Add two numbers")]
sub: Annotated[Callable[[float, float], float], Doc("Subtract two numbers")]
mul: Annotated[Callable[[float, float], float], Doc("Multiply two numbers")]
div: Annotated[Callable[[float, float], float], Doc("Divide two numbers")]
neg: Annotated[Callable[[float], float], Doc("Negate a number")]
id: Annotated[Callable[[float], float], Doc("Identity function")]
unknown: Annotated[Callable[[str], float], Doc("Unknown request")]
50 changes: 50 additions & 0 deletions python/examples/math/schemaV2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Protocol, runtime_checkable


@runtime_checkable
class MathAPI(Protocol):
"""
This is API for a simple calculator
"""

def add(self, x: float, y: float) -> float:
"""
Add two numbers
"""
...

def sub(self, x: float, y: float) -> float:
"""
Subtract two numbers
"""
...

def mul(self, x: float, y: float) -> float:
"""
Multiply two numbers
"""
...

def div(self, x: float, y: float) -> float:
"""
Divide two numbers
"""
...

def neg(self, x: float) -> float:
"""
Negate a number
"""
...

def id(self, x: float, y: float) -> float:
"""
Identity function
"""
...

def unknown(self, text: str) -> float:
"""
unknown request
"""
...
Loading

0 comments on commit 7fa71bc

Please sign in to comment.