diff --git a/libs/colbert/pyproject.toml b/libs/colbert/pyproject.toml index 763f4d984..e2277483a 100644 --- a/libs/colbert/pyproject.toml +++ b/libs/colbert/pyproject.toml @@ -33,7 +33,8 @@ asyncio_mode = "auto" [tool.mypy] strict = true -follow_imports = "normal" +warn_unreachable = true +pretty = true show_error_codes = true show_error_context = true diff --git a/libs/colbert/tests/integration_tests/test_database.py b/libs/colbert/tests/integration_tests/test_database.py index 50a4150ae..60eeefaff 100644 --- a/libs/colbert/tests/integration_tests/test_database.py +++ b/libs/colbert/tests/integration_tests/test_database.py @@ -32,9 +32,7 @@ def test_database_sync(session: Session) -> None: results = database.add_chunks(chunks=[chunk_0, chunk_1]) - assert len(results) == 2 # noqa: PLR2004 - assert results[0] == (doc_id, 0) - assert results[1] == (doc_id, 1) + assert results == [(doc_id, 0), (doc_id, 1)] # TODO: verify other db methods. @@ -71,46 +69,51 @@ async def test_database_async(session: Session) -> None: ) results = await database.aadd_chunks(chunks=[chunk_0, chunk_1]) - assert len(results) == 2 # noqa: PLR2004 - assert results[0] == (doc_id, 0) - assert results[1] == (doc_id, 1) + assert results == [(doc_id, 0), (doc_id, 1)] chunks = await database.search_relevant_chunks( vector=climate_change_embedding[5], n=2 ) - assert len(chunks) == 1 - assert chunks[0].doc_id == doc_id - assert chunks[0].chunk_id == 0 - assert chunks[0].text is None - assert chunks[0].metadata == {} - assert chunks[0].embedding is None + assert chunks == [ + Chunk( + doc_id=doc_id, + chunk_id=0, + embedding=None, + ) + ] chunk = await database.get_chunk_embedding(doc_id=doc_id, chunk_id=1) - assert chunk.doc_id == doc_id - assert chunk.chunk_id == 1 - assert chunk.text is None - assert chunk.metadata == {} - assert chunk.embedding == chunk_1.embedding + assert chunk == Chunk( + doc_id=doc_id, + chunk_id=1, + embedding=chunk_1.embedding, + ) chunk = await database.get_chunk_data(doc_id=doc_id, chunk_id=0) - assert chunk.doc_id == doc_id - assert chunk.chunk_id == 0 - assert chunk.text == chunk_0.text - # this is broken due to a cassio bug - # which converts Number fields to strings - # assert chunk.metadata == chunk_0.metadata - assert chunk.embedding is None + + assert chunk == Chunk( + doc_id=doc_id, + chunk_id=0, + text=chunk_0.text, + # this is broken due to a cassio bug + # which converts Number fields to strings + # metadata=chunk_0.metadata, + embedding=None, + ) chunk = await database.get_chunk_data( doc_id=doc_id, chunk_id=0, include_embedding=True ) - assert chunk.doc_id == doc_id - assert chunk.chunk_id == 0 - assert chunk.text == chunk_0.text - # this is broken due to a cassio bug - # which converts Number fields to strings - # assert chunk.metadata == chunk_0.metadata - assert chunk.embedding == chunk_0.embedding + + assert chunk == Chunk( + doc_id=doc_id, + chunk_id=0, + text=chunk_0.text, + # this is broken due to a cassio bug + # which converts Number fields to strings + # metadata=chunk_0.metadata, + embedding=chunk_0.embedding, + ) result = await database.adelete_chunks(doc_ids=[doc_id]) assert result diff --git a/libs/knowledge-graph/pyproject.toml b/libs/knowledge-graph/pyproject.toml index 4104f2845..04fee1470 100644 --- a/libs/knowledge-graph/pyproject.toml +++ b/libs/knowledge-graph/pyproject.toml @@ -44,7 +44,8 @@ build-backend = "poetry.core.masonry.api" [tool.mypy] strict = true -follow_imports = "normal" +warn_unreachable = true +pretty = true show_error_codes = true show_error_context = true diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py b/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py index 17f4a55fc..e553694d5 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py @@ -109,18 +109,16 @@ def validate_graph_document(self, document: GraphDocument) -> None: relationships = self._relationships.get(r.type, None) if relationships is None: e.add_note(f"No edge type '{r.type}") - else: - relationship = next( - candidate - for candidate in relationships - if r.source.type in candidate.source_types - if r.target.type in candidate.target_types + elif not any( + candidate + for candidate in relationships + if r.source.type in candidate.source_types + and r.target.type in candidate.target_types + ): + e.add_note( + "No relationship allows " + f"({r.source.id} -> {r.type} -> {r.target.type})" ) - if relationship is None: - e.add_note( - "No relationship allows " - f"({r.source.id} -> {r.type} -> {r.target.type})" - ) if e.__notes__: raise e diff --git a/libs/knowledge-store/pyproject.toml b/libs/knowledge-store/pyproject.toml index 5e24a3ca0..a98f32bc2 100644 --- a/libs/knowledge-store/pyproject.toml +++ b/libs/knowledge-store/pyproject.toml @@ -38,7 +38,8 @@ build-backend = "poetry.core.masonry.api" [tool.mypy] strict = true -follow_imports = "normal" +warn_unreachable = true +pretty = true show_error_codes = true show_error_context = true diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index c8ed3c540..e0338ea9b 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -4,17 +4,13 @@ import logging import re import secrets +from collections.abc import Iterable from dataclasses import asdict, dataclass, field, is_dataclass from enum import Enum from typing import ( TYPE_CHECKING, Any, - Dict, - Iterable, - List, Sequence, - Set, - Tuple, Union, cast, ) @@ -70,8 +66,8 @@ class MetadataIndexingMode(Enum): DEFAULT_TO_SEARCHABLE = 2 -MetadataIndexingType = Union[Tuple[str, Iterable[str]], str] -MetadataIndexingPolicy = Tuple[MetadataIndexingMode, Set[str]] +MetadataIndexingType = Union[tuple[str, Iterable[str]], str] +MetadataIndexingPolicy = tuple[MetadataIndexingMode, set[str]] def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) -> bool: @@ -84,7 +80,7 @@ def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) def _serialize_metadata(md: dict[str, Any]) -> str: - if isinstance(md.get("links"), Set): + if isinstance(md.get("links"), set): md = md.copy() md["links"] = list(md["links"]) return json.dumps(md) @@ -93,15 +89,12 @@ def _serialize_metadata(md: dict[str, Any]) -> str: def _serialize_links(links: set[Link]) -> str: class SetAndLinkEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: - if is_dataclass(obj) and not isinstance(obj, type): + if not isinstance(obj, type) and is_dataclass(obj): return asdict(obj) - try: - iterable = iter(obj) - except TypeError: - pass - else: - return list(iterable) + if isinstance(obj, Iterable): + return list(obj) + # Let the base class default method raise the TypeError return super().default(obj) @@ -111,13 +104,13 @@ def default(self, obj: Any) -> Any: def _deserialize_metadata(json_blob: str | None) -> dict[str, Any]: # We don't need to convert the links list back to a set -- it will be # converted when accessed, if needed. - return cast(Dict[str, Any], json.loads(json_blob or "")) + return cast(dict[str, Any], json.loads(json_blob or "")) def _deserialize_links(json_blob: str | None) -> set[Link]: return { Link(kind=link["kind"], direction=link["direction"], tag=link["tag"]) - for link in cast(List[Dict[str, Any]], json.loads(json_blob or "")) + for link in cast(list[dict[str, Any]], json.loads(json_blob or "")) } diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index e65eba610..0361cc3c1 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -45,7 +45,8 @@ mypy = "^1.11.0" [tool.mypy] strict = true -follow_imports = "normal" +warn_unreachable = true +pretty = true show_error_codes = true show_error_context = true diff --git a/libs/llamaindex/pyproject.toml b/libs/llamaindex/pyproject.toml index 116c80b48..2aa24d493 100644 --- a/libs/llamaindex/pyproject.toml +++ b/libs/llamaindex/pyproject.toml @@ -53,7 +53,8 @@ ragstack-ai-colbert = { path = "../colbert", develop = true } [tool.mypy] strict = true -follow_imports = "normal" +warn_unreachable = true +pretty = true show_error_codes = true show_error_context = true diff --git a/libs/ragulate/pyproject.toml b/libs/ragulate/pyproject.toml index f9664cea4..891f67701 100644 --- a/libs/ragulate/pyproject.toml +++ b/libs/ragulate/pyproject.toml @@ -56,7 +56,8 @@ test_integration = "scripts.test_integration_runner:main" [tool.mypy] strict = true -follow_imports = "normal" +warn_unreachable = true +pretty = true show_error_codes = true show_error_context = true diff --git a/libs/tests-utils/pyproject.toml b/libs/tests-utils/pyproject.toml index e3013fb05..c61fe067d 100644 --- a/libs/tests-utils/pyproject.toml +++ b/libs/tests-utils/pyproject.toml @@ -21,7 +21,8 @@ mypy = "^1.10.0" [tool.mypy] strict = true -follow_imports = "normal" +warn_unreachable = true +pretty = true show_error_codes = true show_error_context = true