From 14c0bf1e05b2b2000b32749723b7d3f9f23898b0 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Tue, 26 Nov 2024 13:56:26 -0800 Subject: [PATCH] lintfix --- lib/sycamore/sycamore/docset.py | 6 +++- lib/sycamore/sycamore/query/operators/sort.py | 3 +- .../query/execution/test_sycamore_operator.py | 29 ++++++++++--------- .../tests/unit/transforms/test_sort.py | 12 +++----- lib/sycamore/sycamore/transforms/sort.py | 12 +++++--- 5 files changed, 35 insertions(+), 27 deletions(-) diff --git a/lib/sycamore/sycamore/docset.py b/lib/sycamore/sycamore/docset.py index b315e529c..93700f951 100644 --- a/lib/sycamore/sycamore/docset.py +++ b/lib/sycamore/sycamore/docset.py @@ -1195,7 +1195,11 @@ def sort(self, descending: bool, field: str, default_val: Optional[Any] = None) plan = self.plan if default_val is None: import logging - logging.warning("Default value is none. Adding explicit filter step to drop documents missing the key. This includes any metadata.documents.") + + logging.warning( + "Default value is none. Adding explicit filter step to drop documents missing the key." + " This includes any metadata.documents." + ) plan = DropIfMissingField(plan, field) return DocSet(self.context, Sort(plan, descending, field, default_val)) diff --git a/lib/sycamore/sycamore/query/operators/sort.py b/lib/sycamore/sycamore/query/operators/sort.py index 048a517cf..5e23b88b0 100644 --- a/lib/sycamore/sycamore/query/operators/sort.py +++ b/lib/sycamore/sycamore/query/operators/sort.py @@ -15,4 +15,5 @@ class Sort(Node): field: str """The name of the database field to sort based on.""" - default_value: Any + default_value: Any = None + """The default value used when sorting if a document is missing the specified field.""" diff --git a/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py b/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py index c00edf630..f0051f9d1 100644 --- a/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py +++ b/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py @@ -331,20 +331,23 @@ def test_llm_extract_entity(): def test_sort(): - context = sycamore.init() - doc_set = Mock(spec=DocSet) - return_doc_set = Mock(spec=DocSet) - doc_set.sort.return_value = return_doc_set - logical_node = Sort(node_id=0, descending=True, field="properties.counter", default_value=0) - sycamore_operator = SycamoreSort(context, logical_node, query_id="test", inputs=[doc_set]) - result = sycamore_operator.execute() + Sort(node_id=0, descending=True, field="no-default-value") - doc_set.sort.assert_called_once_with( - descending=logical_node.descending, - field=logical_node.field, - default_val=logical_node.default_value, - ) - assert result == return_doc_set + for default_value in [None, 0]: + context = sycamore.init() + doc_set = Mock(spec=DocSet) + return_doc_set = Mock(spec=DocSet) + doc_set.sort.return_value = return_doc_set + logical_node = Sort(node_id=0, descending=True, field="properties.counter", default_value=default_value) + sycamore_operator = SycamoreSort(context, logical_node, query_id="test", inputs=[doc_set]) + result = sycamore_operator.execute() + + doc_set.sort.assert_called_once_with( + descending=logical_node.descending, + field=logical_node.field, + default_val=logical_node.default_value, + ) + assert result == return_doc_set def test_top_k(): diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_sort.py b/lib/sycamore/sycamore/tests/unit/transforms/test_sort.py index f9e039a96..66268657c 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_sort.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_sort.py @@ -1,5 +1,4 @@ import string -import pytest import random import unittest @@ -14,7 +13,7 @@ class TestSort(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.exec_mode = ExecMode.LOCAL - + def docs(self) -> list[Document]: doc_list = [ # text_representation is random 6 letter strings @@ -56,7 +55,7 @@ def test_default_value(self): assert doc_list[i].properties.get("even", 0) >= doc_list[i - 1].properties.get("even", 0) assert len(doc_list) == self.NUM_DOCS / 2 - + sorted_docset = self.docset().sort(False, "properties.even", 0) doc_list = sorted_docset.take_all() @@ -70,7 +69,7 @@ def test_metadata_document(self): MetadataDocument(), Document(text_representation="C"), MetadataDocument(), - Document(text_representation=None) + Document(text_representation=None), ] context = sycamore.init(exec_mode=self.exec_mode) @@ -82,7 +81,6 @@ def test_metadata_document(self): assert sorted_doc_list[0].text_representation == "B" assert sorted_doc_list[1].text_representation == "C" assert sorted_doc_list[2].text_representation == "Z" - sorted_docset = docset.sort(False, "text_representation", "A") sorted_doc_list = sorted_docset.take_all(include_metadata=True) @@ -102,7 +100,5 @@ def test_metadata_document(self): assert sorted_doc_list[1].text_representation == "C" assert sorted_doc_list[2].text_representation == "B" for i in range(3): - d = sorted_doc_list[i+3] + d = sorted_doc_list[i + 3] assert isinstance(d, MetadataDocument) or d.text_representation is None - - diff --git a/lib/sycamore/sycamore/transforms/sort.py b/lib/sycamore/sycamore/transforms/sort.py index 9719d8a82..1a4cf7d2b 100644 --- a/lib/sycamore/sycamore/transforms/sort.py +++ b/lib/sycamore/sycamore/transforms/sort.py @@ -1,21 +1,24 @@ from typing import Any, Optional, TYPE_CHECKING from sycamore.plan_nodes import Node, Transform -from sycamore.data import Document, MetadataDocument +from sycamore.data import Document if TYPE_CHECKING: from ray.data import Dataset from sycamore.plan_nodes import UnaryNode + + class DropIfMissingField(UnaryNode): """Drop all documents that are missing the specified field, including metadata. This makes ray work because - ray requires a key value that is comparable between all entries which means None can't be used and there - is no easy way to auto-infer the correct type.""" + ray requires a key value that is comparable between all entries which means None can't be used and there + is no easy way to auto-infer the correct type.""" + def __init__(self, child, field): super().__init__(child) self._field = field - def execute(self) -> "Dataset": + def execute(self, **kwargs) -> "Dataset": input_dataset = self.child().execute() result = input_dataset.map_batches(self.ray_drop_documents) return result @@ -28,6 +31,7 @@ def ray_drop_documents(self, ray_input): out_docs = self.local_execute(all_docs) return {"doc": [d.serialize() for d in out_docs]} + class Sort(Transform): """ Sort by field in Document