Skip to content

Commit

Permalink
Actually test the Python translator (#240)
Browse files Browse the repository at this point in the history
* Add a test for the built-in translator and add the snapshots.

* Ensure the repair prompt is actually appended to our messages and update snapshots.

* Add uncommitted file.
  • Loading branch information
DanielRosenwasser authored Apr 20, 2024
1 parent 66fd7bb commit 9f644b6
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 12 deletions.
4 changes: 2 additions & 2 deletions python/examples/healthData/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(
self._additional_agent_instructions = additional_agent_instructions

@override
async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
result = await super().translate(request=request, prompt_preamble=prompt_preamble)
async def translate(self, input: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
result = await super().translate(input=input, prompt_preamble=prompt_preamble)
if not isinstance(result, Failure):
self._chat_history.append(ChatMessage(source="assistant", body=result.value))
return result
Expand Down
22 changes: 12 additions & 10 deletions python/src/typechat/_internal/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,31 +49,33 @@ def __init__(
self._type_name = conversion_result.typescript_type_reference
self._schema_str = conversion_result.typescript_schema_str

async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
async def translate(self, input: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
"""
Translates a natural language request into an object of type `T`. If the JSON object returned by
the language model fails to validate, repair attempts will be made up until `_max_repair_attempts`.
The prompt for the subsequent attempts will include the diagnostics produced for the prior attempt.
This often helps produce a valid instance.
Args:
request: A natural language request.
input: A natural language request.
prompt_preamble: An optional string or list of prompt sections to prepend to the generated prompt.\
If a string is given, it is converted to a single "user" role prompt section.
"""
request = self._create_request_prompt(request)

prompt: str | list[PromptSection]
if prompt_preamble is None:
prompt = request
else:
messages: list[PromptSection] = []

messages.append({"role": "user", "content": input})

This comment has been minimized.

Copy link
@gvanrossum

gvanrossum Apr 22, 2024

Contributor

I believe this line shouldn't be here -- the user input now gets added twice, once at the very start, and again as part of the message created by self._create_request_prompt().

if prompt_preamble:
if isinstance(prompt_preamble, str):
prompt_preamble = [{"role": "user", "content": prompt_preamble}]
prompt = [*prompt_preamble, {"role": "user", "content": request}]
else:
messages.extend(prompt_preamble)

messages.append({"role": "user", "content": self._create_request_prompt(input)})

num_repairs_attempted = 0
while True:
completion_response = await self.model.complete(prompt)
completion_response = await self.model.complete(messages)
if isinstance(completion_response, Failure):
return completion_response

Expand All @@ -93,7 +95,7 @@ async def translate(self, request: str, *, prompt_preamble: str | list[PromptSec
if num_repairs_attempted >= self._max_repair_attempts:
return Failure(error_message)
num_repairs_attempted += 1
request = f"{text_response}\n{self._create_repair_prompt(error_message)}"
messages.append({"role": "user", "content": self._create_repair_prompt(error_message)})

def _create_request_prompt(self, intent: str) -> str:
prompt = f"""
Expand Down
137 changes: 137 additions & 0 deletions python/tests/__snapshots__/test_translator.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# serializer version: 1
# name: test_translator_with_immediate_pass
list([
dict({
'kind': 'CLIENT REQUEST',
'payload': list([
dict({
'content': 'Get me stuff.',
'role': 'user',
}),
dict({
'content': '''

You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions:
```
interface ExampleABC {
a: string;
b: boolean;
c: number;
}

```
The following is a user request:
'''
Get me stuff.
'''
The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined:

''',
'role': 'user',
}),
]),
}),
dict({
'kind': 'MODEL RESPONSE',
'payload': '{ "a": "hello", "b": true, "c": 1234 }',
}),
])
# ---
# name: test_translator_with_single_failure
list([
dict({
'kind': 'CLIENT REQUEST',
'payload': list([
dict({
'content': 'Get me stuff.',
'role': 'user',
}),
dict({
'content': '''

You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions:
```
interface ExampleABC {
a: string;
b: boolean;
c: number;
}

```
The following is a user request:
'''
Get me stuff.
'''
The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined:

''',
'role': 'user',
}),
dict({
'content': '''

The above JSON object is invalid for the following reason:
'''
Validation path `c` failed for value `{"a": "hello", "b": true}` because:
Field required
'''
The following is a revised JSON object:

''',
'role': 'user',
}),
]),
}),
dict({
'kind': 'MODEL RESPONSE',
'payload': '{ "a": "hello", "b": true }',
}),
dict({
'kind': 'CLIENT REQUEST',
'payload': list([
dict({
'content': 'Get me stuff.',
'role': 'user',
}),
dict({
'content': '''

You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions:
```
interface ExampleABC {
a: string;
b: boolean;
c: number;
}

```
The following is a user request:
'''
Get me stuff.
'''
The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined:

''',
'role': 'user',
}),
dict({
'content': '''

The above JSON object is invalid for the following reason:
'''
Validation path `c` failed for value `{"a": "hello", "b": true}` because:
Field required
'''
The following is a revised JSON object:

''',
'role': 'user',
}),
]),
}),
dict({
'kind': 'MODEL RESPONSE',
'payload': '{ "a": "hello", "b": true, "c": 1234 }',
}),
])
# ---
53 changes: 53 additions & 0 deletions python/tests/test_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@

import asyncio
from dataclasses import dataclass
from typing_extensions import Any, Iterator, Literal, TypedDict, override
import typechat

class ConvoRecord(TypedDict):
kind: Literal["CLIENT REQUEST", "MODEL RESPONSE"]
payload: str | list[typechat.PromptSection]

class FixedModel(typechat.TypeChatLanguageModel):
responses: Iterator[str]
conversation: list[ConvoRecord]

"A model which responds with one of a series of responses."
def __init__(self, responses: list[str]) -> None:
super().__init__()
self.responses = iter(responses)
self.conversation = []

@override
async def complete(self, prompt: str | list[typechat.PromptSection]) -> typechat.Result[str]:
self.conversation.append({ "kind": "CLIENT REQUEST", "payload": prompt })
response = next(self.responses)
self.conversation.append({ "kind": "MODEL RESPONSE", "payload": response })
return typechat.Success(response)

@dataclass
class ExampleABC:
a: str
b: bool
c: int

v = typechat.TypeChatValidator(ExampleABC)

def test_translator_with_immediate_pass(snapshot: Any):
m = FixedModel([
'{ "a": "hello", "b": true, "c": 1234 }',
])
t = typechat.TypeChatJsonTranslator(m, v, ExampleABC)
asyncio.run(t.translate("Get me stuff."))

assert m.conversation == snapshot

def test_translator_with_single_failure(snapshot: Any):
m = FixedModel([
'{ "a": "hello", "b": true }',
'{ "a": "hello", "b": true, "c": 1234 }',
])
t = typechat.TypeChatJsonTranslator(m, v, ExampleABC)
asyncio.run(t.translate("Get me stuff."))

assert m.conversation == snapshot

0 comments on commit 9f644b6

Please sign in to comment.