Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add object converter #1113

Merged
merged 1 commit into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions agixt/extensions/agixt_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import requests
import os
import re
from typing import List
from typing import List, Type
from pydantic import BaseModel
from Extensions import Extensions
from local_llm import LLM
from ApiClient import Chain
Expand Down Expand Up @@ -165,6 +166,7 @@ def __init__(self, **kwargs):
"Get CSV Preview": self.get_csv_preview,
"Get CSV Preview Text": self.get_csv_preview_text,
"Strip CSV Data from Code Block": self.get_csv_from_response,
"Convert a string to a Pydantic model": self.convert_string_to_pydantic_model,
}

for chain in chains:
Expand All @@ -179,6 +181,7 @@ def __init__(self, **kwargs):
self.WORKING_DIRECTORY = os.path.join(os.getcwd(), "WORKSPACE")
os.makedirs(self.WORKING_DIRECTORY, exist_ok=True)
self.ApiClient = kwargs["ApiClient"] if "ApiClient" in kwargs else None
self.failures = 0

async def models(self):
return LLM().models()
Expand Down Expand Up @@ -352,9 +355,11 @@ def resolve_schema(ref):
"in": param.get("in", ""),
"description": param.get("description", ""),
"required": param.get("required", False),
"type": param.get("schema", {}).get("type", "")
if "schema" in param
else "",
"type": (
param.get("schema", {}).get("type", "")
if "schema" in param
else ""
),
}
endpoint_info["parameters"].append(param_info)
if "requestBody" in method_info:
Expand Down Expand Up @@ -665,3 +670,36 @@ async def convert_questions_to_dataset(self, response):
)
)
tasks.append(task)

async def convert_string_to_pydantic_model(
self, input_string: str, output_model: Type[BaseModel]
):
fields = output_model.model_fields
field_descriptions = [f"{field}: {fields[field]}" for field in fields]
schema = "\n".join(field_descriptions)
response = self.ApiClient.prompt_agent(
agent_name=self.agent_name,
prompt_name="Convert to JSON",
prompt_args={
"user_input": input_string,
"schema": schema,
"conversation_name": "AGiXT Terminal",
},
)
response = str(response).split("```json")[1].split("```")[0].strip()
try:
response = json.loads(response)
return output_model(**response)
except:
self.failures += 1
logging.warning(f"Failed to convert response, the response was: {response}")
logging.info(f"[{self.failures}/3] Retrying conversion")
if self.failures < 3:
return await self.convert_string_to_pydantic_model(
input_string=input_string, output_model=output_model
)
else:
logging.error(
"Failed to convert response after 3 attempts, returning empty string."
)
return ""
11 changes: 11 additions & 0 deletions agixt/prompts/Default/Convert to JSON.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
**Act as a JSON converter that converts any text into the desired JSON format based on the schema provided. Respond only with JSON in a properly formatted markdown code block, no explanations.**

**Reformat the following information into a structured format according to the schema provided:**

## Information:
{user_input}

## Schema:
{schema}

JSON Structured Output:
Loading