Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mypy warn_unreachable #623

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libs/colbert/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ asyncio_mode = "auto"

[tool.mypy]
strict = true
follow_imports = "normal"
warn_unreachable = true
pretty = true
show_error_codes = true
show_error_context = true

Expand Down
65 changes: 34 additions & 31 deletions libs/colbert/tests/integration_tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def test_database_sync(session: Session) -> None:

results = database.add_chunks(chunks=[chunk_0, chunk_1])

assert len(results) == 2 # noqa: PLR2004
assert results[0] == (doc_id, 0)
assert results[1] == (doc_id, 1)
assert results == [(doc_id, 0), (doc_id, 1)]

# TODO: verify other db methods.

Expand Down Expand Up @@ -71,46 +69,51 @@ async def test_database_async(session: Session) -> None:
)

results = await database.aadd_chunks(chunks=[chunk_0, chunk_1])
assert len(results) == 2 # noqa: PLR2004
assert results[0] == (doc_id, 0)
assert results[1] == (doc_id, 1)
assert results == [(doc_id, 0), (doc_id, 1)]

chunks = await database.search_relevant_chunks(
vector=climate_change_embedding[5], n=2
)
assert len(chunks) == 1
assert chunks[0].doc_id == doc_id
assert chunks[0].chunk_id == 0
assert chunks[0].text is None
assert chunks[0].metadata == {}
assert chunks[0].embedding is None
assert chunks == [
Chunk(
doc_id=doc_id,
chunk_id=0,
embedding=None,
)
]

chunk = await database.get_chunk_embedding(doc_id=doc_id, chunk_id=1)
assert chunk.doc_id == doc_id
assert chunk.chunk_id == 1
assert chunk.text is None
assert chunk.metadata == {}
assert chunk.embedding == chunk_1.embedding
assert chunk == Chunk(
doc_id=doc_id,
chunk_id=1,
embedding=chunk_1.embedding,
)

chunk = await database.get_chunk_data(doc_id=doc_id, chunk_id=0)
assert chunk.doc_id == doc_id
assert chunk.chunk_id == 0
assert chunk.text == chunk_0.text
# this is broken due to a cassio bug
# which converts Number fields to strings
# assert chunk.metadata == chunk_0.metadata
assert chunk.embedding is None

assert chunk == Chunk(
doc_id=doc_id,
chunk_id=0,
text=chunk_0.text,
# this is broken due to a cassio bug
# which converts Number fields to strings
# metadata=chunk_0.metadata,
embedding=None,
)

chunk = await database.get_chunk_data(
doc_id=doc_id, chunk_id=0, include_embedding=True
)
assert chunk.doc_id == doc_id
assert chunk.chunk_id == 0
assert chunk.text == chunk_0.text
# this is broken due to a cassio bug
# which converts Number fields to strings
# assert chunk.metadata == chunk_0.metadata
assert chunk.embedding == chunk_0.embedding

assert chunk == Chunk(
doc_id=doc_id,
chunk_id=0,
text=chunk_0.text,
# this is broken due to a cassio bug
# which converts Number fields to strings
# metadata=chunk_0.metadata,
embedding=chunk_0.embedding,
)

result = await database.adelete_chunks(doc_ids=[doc_id])
assert result
3 changes: 2 additions & 1 deletion libs/knowledge-graph/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ build-backend = "poetry.core.masonry.api"

[tool.mypy]
strict = true
follow_imports = "normal"
warn_unreachable = true
pretty = true
show_error_codes = true
show_error_context = true

Expand Down
20 changes: 9 additions & 11 deletions libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,16 @@ def validate_graph_document(self, document: GraphDocument) -> None:
relationships = self._relationships.get(r.type, None)
if relationships is None:
e.add_note(f"No edge type '{r.type}")
else:
relationship = next(
candidate
for candidate in relationships
if r.source.type in candidate.source_types
if r.target.type in candidate.target_types
elif not any(
candidate
for candidate in relationships
if r.source.type in candidate.source_types
and r.target.type in candidate.target_types
):
e.add_note(
"No relationship allows "
f"({r.source.id} -> {r.type} -> {r.target.type})"
)
if relationship is None:
e.add_note(
"No relationship allows "
f"({r.source.id} -> {r.type} -> {r.target.type})"
)

if e.__notes__:
raise e
3 changes: 2 additions & 1 deletion libs/knowledge-store/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ build-backend = "poetry.core.masonry.api"

[tool.mypy]
strict = true
follow_imports = "normal"
warn_unreachable = true
pretty = true
show_error_codes = true
show_error_context = true

Expand Down
27 changes: 10 additions & 17 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,13 @@
import logging
import re
import secrets
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field, is_dataclass
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Sequence,
Set,
Tuple,
Union,
cast,
)
Expand Down Expand Up @@ -70,8 +66,8 @@ class MetadataIndexingMode(Enum):
DEFAULT_TO_SEARCHABLE = 2


MetadataIndexingType = Union[Tuple[str, Iterable[str]], str]
MetadataIndexingPolicy = Tuple[MetadataIndexingMode, Set[str]]
MetadataIndexingType = Union[tuple[str, Iterable[str]], str]
MetadataIndexingPolicy = tuple[MetadataIndexingMode, set[str]]


def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) -> bool:
Expand All @@ -84,7 +80,7 @@ def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy)


def _serialize_metadata(md: dict[str, Any]) -> str:
if isinstance(md.get("links"), Set):
if isinstance(md.get("links"), set):
md = md.copy()
md["links"] = list(md["links"])
return json.dumps(md)
Expand All @@ -93,15 +89,12 @@ def _serialize_metadata(md: dict[str, Any]) -> str:
def _serialize_links(links: set[Link]) -> str:
class SetAndLinkEncoder(json.JSONEncoder):
def default(self, obj: Any) -> Any:
if is_dataclass(obj) and not isinstance(obj, type):
if not isinstance(obj, type) and is_dataclass(obj):
return asdict(obj)

try:
iterable = iter(obj)
except TypeError:
pass
else:
return list(iterable)
if isinstance(obj, Iterable):
return list(obj)

# Let the base class default method raise the TypeError
return super().default(obj)

Expand All @@ -111,13 +104,13 @@ def default(self, obj: Any) -> Any:
def _deserialize_metadata(json_blob: str | None) -> dict[str, Any]:
# We don't need to convert the links list back to a set -- it will be
# converted when accessed, if needed.
return cast(Dict[str, Any], json.loads(json_blob or ""))
return cast(dict[str, Any], json.loads(json_blob or ""))


def _deserialize_links(json_blob: str | None) -> set[Link]:
return {
Link(kind=link["kind"], direction=link["direction"], tag=link["tag"])
for link in cast(List[Dict[str, Any]], json.loads(json_blob or ""))
for link in cast(list[dict[str, Any]], json.loads(json_blob or ""))
}


Expand Down
3 changes: 2 additions & 1 deletion libs/langchain/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ mypy = "^1.11.0"

[tool.mypy]
strict = true
follow_imports = "normal"
warn_unreachable = true
pretty = true
show_error_codes = true
show_error_context = true

Expand Down
3 changes: 2 additions & 1 deletion libs/llamaindex/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ ragstack-ai-colbert = { path = "../colbert", develop = true }

[tool.mypy]
strict = true
follow_imports = "normal"
warn_unreachable = true
pretty = true
show_error_codes = true
show_error_context = true

Expand Down
3 changes: 2 additions & 1 deletion libs/ragulate/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ test_integration = "scripts.test_integration_runner:main"

[tool.mypy]
strict = true
follow_imports = "normal"
warn_unreachable = true
pretty = true
show_error_codes = true
show_error_context = true

Expand Down
3 changes: 2 additions & 1 deletion libs/tests-utils/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ mypy = "^1.10.0"

[tool.mypy]
strict = true
follow_imports = "normal"
warn_unreachable = true
pretty = true
show_error_codes = true
show_error_context = true

Expand Down
Loading