-
Notifications
You must be signed in to change notification settings - Fork 390
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Actually test the Python translator (#240)
* Add a test for the built-in translator and add the snapshots. * Ensure the repair prompt is actually appended to our messages and update snapshots. * Add uncommitted file.
- Loading branch information
1 parent
66fd7bb
commit 9f644b6
Showing
4 changed files
with
204 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# serializer version: 1 | ||
# name: test_translator_with_immediate_pass | ||
list([ | ||
dict({ | ||
'kind': 'CLIENT REQUEST', | ||
'payload': list([ | ||
dict({ | ||
'content': 'Get me stuff.', | ||
'role': 'user', | ||
}), | ||
dict({ | ||
'content': ''' | ||
|
||
You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: | ||
``` | ||
interface ExampleABC { | ||
a: string; | ||
b: boolean; | ||
c: number; | ||
} | ||
|
||
``` | ||
The following is a user request: | ||
''' | ||
Get me stuff. | ||
''' | ||
The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: | ||
|
||
''', | ||
'role': 'user', | ||
}), | ||
]), | ||
}), | ||
dict({ | ||
'kind': 'MODEL RESPONSE', | ||
'payload': '{ "a": "hello", "b": true, "c": 1234 }', | ||
}), | ||
]) | ||
# --- | ||
# name: test_translator_with_single_failure | ||
list([ | ||
dict({ | ||
'kind': 'CLIENT REQUEST', | ||
'payload': list([ | ||
dict({ | ||
'content': 'Get me stuff.', | ||
'role': 'user', | ||
}), | ||
dict({ | ||
'content': ''' | ||
|
||
You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: | ||
``` | ||
interface ExampleABC { | ||
a: string; | ||
b: boolean; | ||
c: number; | ||
} | ||
|
||
``` | ||
The following is a user request: | ||
''' | ||
Get me stuff. | ||
''' | ||
The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: | ||
|
||
''', | ||
'role': 'user', | ||
}), | ||
dict({ | ||
'content': ''' | ||
|
||
The above JSON object is invalid for the following reason: | ||
''' | ||
Validation path `c` failed for value `{"a": "hello", "b": true}` because: | ||
Field required | ||
''' | ||
The following is a revised JSON object: | ||
|
||
''', | ||
'role': 'user', | ||
}), | ||
]), | ||
}), | ||
dict({ | ||
'kind': 'MODEL RESPONSE', | ||
'payload': '{ "a": "hello", "b": true }', | ||
}), | ||
dict({ | ||
'kind': 'CLIENT REQUEST', | ||
'payload': list([ | ||
dict({ | ||
'content': 'Get me stuff.', | ||
'role': 'user', | ||
}), | ||
dict({ | ||
'content': ''' | ||
|
||
You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: | ||
``` | ||
interface ExampleABC { | ||
a: string; | ||
b: boolean; | ||
c: number; | ||
} | ||
|
||
``` | ||
The following is a user request: | ||
''' | ||
Get me stuff. | ||
''' | ||
The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: | ||
|
||
''', | ||
'role': 'user', | ||
}), | ||
dict({ | ||
'content': ''' | ||
|
||
The above JSON object is invalid for the following reason: | ||
''' | ||
Validation path `c` failed for value `{"a": "hello", "b": true}` because: | ||
Field required | ||
''' | ||
The following is a revised JSON object: | ||
|
||
''', | ||
'role': 'user', | ||
}), | ||
]), | ||
}), | ||
dict({ | ||
'kind': 'MODEL RESPONSE', | ||
'payload': '{ "a": "hello", "b": true, "c": 1234 }', | ||
}), | ||
]) | ||
# --- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
|
||
import asyncio | ||
from dataclasses import dataclass | ||
from typing_extensions import Any, Iterator, Literal, TypedDict, override | ||
import typechat | ||
|
||
class ConvoRecord(TypedDict): | ||
kind: Literal["CLIENT REQUEST", "MODEL RESPONSE"] | ||
payload: str | list[typechat.PromptSection] | ||
|
||
class FixedModel(typechat.TypeChatLanguageModel): | ||
responses: Iterator[str] | ||
conversation: list[ConvoRecord] | ||
|
||
"A model which responds with one of a series of responses." | ||
def __init__(self, responses: list[str]) -> None: | ||
super().__init__() | ||
self.responses = iter(responses) | ||
self.conversation = [] | ||
|
||
@override | ||
async def complete(self, prompt: str | list[typechat.PromptSection]) -> typechat.Result[str]: | ||
self.conversation.append({ "kind": "CLIENT REQUEST", "payload": prompt }) | ||
response = next(self.responses) | ||
self.conversation.append({ "kind": "MODEL RESPONSE", "payload": response }) | ||
return typechat.Success(response) | ||
|
||
@dataclass | ||
class ExampleABC: | ||
a: str | ||
b: bool | ||
c: int | ||
|
||
v = typechat.TypeChatValidator(ExampleABC) | ||
|
||
def test_translator_with_immediate_pass(snapshot: Any): | ||
m = FixedModel([ | ||
'{ "a": "hello", "b": true, "c": 1234 }', | ||
]) | ||
t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) | ||
asyncio.run(t.translate("Get me stuff.")) | ||
|
||
assert m.conversation == snapshot | ||
|
||
def test_translator_with_single_failure(snapshot: Any): | ||
m = FixedModel([ | ||
'{ "a": "hello", "b": true }', | ||
'{ "a": "hello", "b": true, "c": 1234 }', | ||
]) | ||
t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) | ||
asyncio.run(t.translate("Get me stuff.")) | ||
|
||
assert m.conversation == snapshot |
I believe this line shouldn't be here -- the user input now gets added twice, once at the very start, and again as part of the message created by
self._create_request_prompt()
.