Skip to content

Commit

Permalink
Support user-defined validation
Browse files Browse the repository at this point in the history
  • Loading branch information
katsumiok committed Feb 15, 2024
1 parent f5e6312 commit ee5ebad
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 5 deletions.
13 changes: 12 additions & 1 deletion pyaskit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
import warnings

chat_function = chat_with_retry
history = []


def get_history():
return history


def clear_history():
history.clear()


def set_chat_function(func):
Expand All @@ -13,7 +22,9 @@ def set_chat_function(func):


def chat(messages):
return chat_function(messages)
content, completion = chat_function(messages)
history.append(messages + [{"role": "assistant", "content": content}])
return content, completion


def use_llama(
Expand Down
14 changes: 13 additions & 1 deletion pyaskit/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,20 @@ def query(
var_map: dict,
return_type,
training_examples: ExampleType,
validator
):
messages = make_messages(task, return_type, var_map, training_examples)
return ask_and_parse(return_type, messages)
data, reason, errors, completion = ask_and_parse(return_type, messages)
for _ in range(10):
if validator is None:
break
if validator.is_valid(data):
break
new_messages = messages.copy()
new_messages.append({"role": "assistant", "content": make_answer(data)})
new_messages.append({"role": "user", "content": f"Correct the answer in JSON again to solve the following error: {validator.feedback}\nProvide the whole answer."})
data, reason, errors, completion = ask_and_parse(return_type, new_messages)
return data, reason, errors, completion


def chat_raw(return_type, messages):
Expand Down Expand Up @@ -116,6 +127,7 @@ def ask_and_parse(return_type, messages):
messages.append({"role": "user", "content": s})
retry = True
errors.append(str(e))
return None, "Retry limit exceeded", errors, completion


def generate_schema(return_type):
Expand Down
22 changes: 19 additions & 3 deletions pyaskit/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .path import add_to_sys_path
from .example import ExampleType, check_examples
from .logging_config import setup_logger
from .core import get_history, clear_history


logger = setup_logger(__name__)
Expand Down Expand Up @@ -38,15 +39,25 @@ def __init__(
self._errors: List[str] = []
self._completion = None
self._recompilation_count = 0
self._validator = None
self._history = []

def set_validator(self, validator):
self._validator = validator

def __call__(self, *args, **kwargs):
converted_template = convert_template(self.template)
variableMap = {}
self.check_args(args, kwargs, self.variables, variableMap)

clear_history()
result, self._reason, self._errors, self._completion = query(
converted_template, variableMap, self.return_type, self.training_examples
converted_template,
variableMap,
self.return_type,
self.training_examples,
self._validator,
)
self._history = get_history()
return result

@property
Expand All @@ -65,6 +76,10 @@ def completion(self):
def recompilation_count(self):
return self._recompilation_count

@property
def history(self):
return self._history

def check_args(self, args, kwargs, variables, variableMap):
for var, arg in zip(variables, args):
if var in kwargs:
Expand Down Expand Up @@ -95,13 +110,14 @@ def compile(self, test_examples: ExampleType = []):
self.training_examples,
)
# print("Prompt:", prompt)
code, self._recompilation_count = implement_body(
code, self._recompilation_count, retry_count = implement_body(
function_name, prompt, test_examples
)
os.makedirs(module_path, exist_ok=True)
with open(module_file_path, "w") as f:
f.write(
"# Recompilation count: " + str(self._recompilation_count) + "\n"
"# Retry count: " + str(retry_count) + "\n"
)
f.write(code)
else:
Expand Down
1 change: 1 addition & 0 deletions pyaskit/py_askit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def defun(
param_types: Dict[str, ParamType],
template: str,
training_examples: ExampleType = [],
validate=None,
):
return Function(return_type, param_types, template, training_examples)

Expand Down

0 comments on commit ee5ebad

Please sign in to comment.