Skip to content

Commit

Permalink
lintfix
Browse files (browse the repository at this point in the history)
  • Loading branch information
eric-anderson committed Nov 26, 2024
1 parent a0049b4 commit 14c0bf1
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 27 deletions.
6 changes: 5 additions & 1 deletion lib/sycamore/sycamore/docset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,11 @@ def sort(self, descending: bool, field: str, default_val: Optional[Any] = None)
plan = self.plan
if default_val is None:
import logging
logging.warning("Default value is none. Adding explicit filter step to drop documents missing the key. This includes any metadata.documents.")

logging.warning(
"Default value is none. Adding explicit filter step to drop documents missing the key."
" This includes any metadata.documents."
)
plan = DropIfMissingField(plan, field)
return DocSet(self.context, Sort(plan, descending, field, default_val))

Expand Down
3 changes: 2 additions & 1 deletion lib/sycamore/sycamore/query/operators/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ class Sort(Node):
field: str
"""The name of the database field to sort based on."""

default_value: Any
default_value: Any = None
"""The default value used when sorting if a document is missing the specified field."""
Original file line number Diff line number Diff line change
Expand Up @@ -331,20 +331,23 @@ def test_llm_extract_entity():


def test_sort():
context = sycamore.init()
doc_set = Mock(spec=DocSet)
return_doc_set = Mock(spec=DocSet)
doc_set.sort.return_value = return_doc_set
logical_node = Sort(node_id=0, descending=True, field="properties.counter", default_value=0)
sycamore_operator = SycamoreSort(context, logical_node, query_id="test", inputs=[doc_set])
result = sycamore_operator.execute()
Sort(node_id=0, descending=True, field="no-default-value")

doc_set.sort.assert_called_once_with(
descending=logical_node.descending,
field=logical_node.field,
default_val=logical_node.default_value,
)
assert result == return_doc_set
for default_value in [None, 0]:
context = sycamore.init()
doc_set = Mock(spec=DocSet)
return_doc_set = Mock(spec=DocSet)
doc_set.sort.return_value = return_doc_set
logical_node = Sort(node_id=0, descending=True, field="properties.counter", default_value=default_value)
sycamore_operator = SycamoreSort(context, logical_node, query_id="test", inputs=[doc_set])
result = sycamore_operator.execute()

doc_set.sort.assert_called_once_with(
descending=logical_node.descending,
field=logical_node.field,
default_val=logical_node.default_value,
)
assert result == return_doc_set


def test_top_k():
Expand Down
12 changes: 4 additions & 8 deletions lib/sycamore/sycamore/tests/unit/transforms/test_sort.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import string
import pytest
import random
import unittest

Expand All @@ -14,7 +13,7 @@ class TestSort(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.exec_mode = ExecMode.LOCAL

def docs(self) -> list[Document]:
doc_list = [
# text_representation is random 6 letter strings
Expand Down Expand Up @@ -56,7 +55,7 @@ def test_default_value(self):
assert doc_list[i].properties.get("even", 0) >= doc_list[i - 1].properties.get("even", 0)

assert len(doc_list) == self.NUM_DOCS / 2

sorted_docset = self.docset().sort(False, "properties.even", 0)
doc_list = sorted_docset.take_all()

Expand All @@ -70,7 +69,7 @@ def test_metadata_document(self):
MetadataDocument(),
Document(text_representation="C"),
MetadataDocument(),
Document(text_representation=None)
Document(text_representation=None),
]

context = sycamore.init(exec_mode=self.exec_mode)
Expand All @@ -82,7 +81,6 @@ def test_metadata_document(self):
assert sorted_doc_list[0].text_representation == "B"
assert sorted_doc_list[1].text_representation == "C"
assert sorted_doc_list[2].text_representation == "Z"


sorted_docset = docset.sort(False, "text_representation", "A")
sorted_doc_list = sorted_docset.take_all(include_metadata=True)
Expand All @@ -102,7 +100,5 @@ def test_metadata_document(self):
assert sorted_doc_list[1].text_representation == "C"
assert sorted_doc_list[2].text_representation == "B"
for i in range(3):
d = sorted_doc_list[i+3]
d = sorted_doc_list[i + 3]
assert isinstance(d, MetadataDocument) or d.text_representation is None


12 changes: 8 additions & 4 deletions lib/sycamore/sycamore/transforms/sort.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from typing import Any, Optional, TYPE_CHECKING

from sycamore.plan_nodes import Node, Transform
from sycamore.data import Document, MetadataDocument
from sycamore.data import Document

if TYPE_CHECKING:
from ray.data import Dataset

from sycamore.plan_nodes import UnaryNode


class DropIfMissingField(UnaryNode):
"""Drop all documents that are missing the specified field, including metadata. This makes ray work because
ray requires a key value that is comparable between all entries which means None can't be used and there
is no easy way to auto-infer the correct type."""
ray requires a key value that is comparable between all entries which means None can't be used and there
is no easy way to auto-infer the correct type."""

def __init__(self, child, field):
super().__init__(child)
self._field = field

def execute(self) -> "Dataset":
def execute(self, **kwargs) -> "Dataset":
input_dataset = self.child().execute()
result = input_dataset.map_batches(self.ray_drop_documents)
return result
Expand All @@ -28,6 +31,7 @@ def ray_drop_documents(self, ray_input):
out_docs = self.local_execute(all_docs)
return {"doc": [d.serialize() for d in out_docs]}


class Sort(Transform):
"""
Sort by field in Document
Expand Down

0 comments on commit 14c0bf1

Please sign in to comment.