Skip to content

Commit

Permalink
Test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
baitsguy committed Dec 20, 2024
1 parent 0c56bfc commit 38c469e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
4 changes: 2 additions & 2 deletions lib/sycamore/sycamore/llms/prompts/default_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,12 @@ def _format_schema(schema: Schema) -> str:


class PropertiesZeroShotGuidancePrompt(SimplePrompt):
def __init__(self, entity: str, properties: Any, text: str):
def __init__(self):
super().__init__()

self.system = "You are a helpful property extractor. You only return JSON."

self.user = f"""You are given a few text elements of a document. Extract JSON representing one entity of
self.user = """You are given a few text elements of a document. Extract JSON representing one entity of
class {entity} from the document. The class only has properties {properties}. Using
this context, FIND, FORMAT, and RETURN the JSON representing one {entity}.
Only return JSON as part of your answer. If no entity is in the text, return "None".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sycamore.transforms.extract_schema import LLMPropertyExtractor


def test_extract_properties_from_schema():
def get_docs():
docs = [
Document(
{
Expand All @@ -23,6 +23,27 @@ def test_extract_properties_from_schema():
}
),
]
return docs


def test_extract_properties_from_dict_schema():
    """Extract properties using a plain ``dict`` schema and check the first document's entity."""
    # Only the first doc is validated because of technique reliability.
    first_doc_only = get_docs()[:1]
    dict_schema = {"name": "str", "age": "int", "date": "str", "from_location": "str"}
    extractor = LLMPropertyExtractor(
        OpenAI(OpenAIModels.GPT_4O),
        schema=dict_schema,
        schema_name="entity",
    )

    context = sycamore.init(exec_mode=ExecMode.LOCAL)
    docset = context.read.document(first_doc_only).extract_properties(extractor)

    results = docset.take_all()
    entity = results[0].properties["entity"]

    assert entity["name"] == "Vinayak"
    assert entity["age"] == 74
    assert "Honolulu" in entity["from_location"]


def test_extract_properties_from_schema():
docs = get_docs()

schema = Schema(
fields=[
Expand Down
12 changes: 7 additions & 5 deletions lib/sycamore/sycamore/transforms/extract_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional, Tuple
from typing import Callable, Any, Optional, Tuple, Union
import json

from sycamore.data import Element, Document
Expand Down Expand Up @@ -156,7 +156,7 @@ def __init__(
self,
llm: LLM,
schema_name: Optional[str] = None,
schema: Optional[Tuple[dict[str, str], Schema]] = None,
schema: Optional[Union[dict[str, str], Schema]] = None,
num_of_elements: int = 10,
prompt_formatter: Callable[[list[Element]], str] = element_list_formatter,
):
Expand Down Expand Up @@ -192,19 +192,21 @@ def _handle_zero_shot_prompting(self, document: Document) -> Any:
text = self._prompt_formatter(
[document.elements[i] for i in range((min(self._num_of_elements, len(document.elements))))]
)

if isinstance(self._schema, Schema):
prompt = ExtractPropertiesFromSchemaPrompt(schema=self._schema, text=text)
entities = self._llm.generate(prompt_kwargs={"prompt": prompt})
else:
schema = self._schema or document.properties.get("_schema")
assert schema is not None, "Schema must be provided or detected before extracting properties."

schema_name = self._schema_name or document.properties.get("_schema_class")
assert schema_name is not None, "Schema name must be provided or detected before extracting properties."

prompt = PropertiesZeroShotGuidancePrompt(entity=schema_name, properties=schema, text=text)
prompt = PropertiesZeroShotGuidancePrompt()

entities = self._llm.generate(prompt_kwargs={"prompt": prompt})
entities = self._llm.generate(
prompt_kwargs={"prompt": prompt, "entity": schema_name, "properties": schema, "text": text}
)
return entities


Expand Down

0 comments on commit 38c469e

Please sign in to comment.