mypy: make quixstreams.models.* pass type checks #673

Merged · 3 commits · Dec 6, 2024 · Changes from 2 commits
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -114,14 +114,14 @@ ignore_missing_imports = true
module = [
"quixstreams.sinks.community.*",
"quixstreams.sources.community.*",
"quixstreams.models.serializers.quix.*",
]
ignore_errors = true

[[tool.mypy.overrides]]
module = [
"quixstreams.core.*",
"quixstreams.dataframe.*",
"quixstreams.models.*",
"quixstreams.platforms.*",
"quixstreams.rowproducer.*"
]
6 changes: 3 additions & 3 deletions quixstreams/error_callbacks.py
@@ -1,18 +1,18 @@
import logging
from typing import Callable, Optional

from .models import ConfluentKafkaMessageProto, Row
from .models import RawConfluentKafkaMessageProto, Row

ProcessingErrorCallback = Callable[[Exception, Optional[Row], logging.Logger], bool]
ConsumerErrorCallback = Callable[
[Exception, Optional[ConfluentKafkaMessageProto], logging.Logger], bool
[Exception, Optional[RawConfluentKafkaMessageProto], logging.Logger], bool
]
ProducerErrorCallback = Callable[[Exception, Optional[Row], logging.Logger], bool]


def default_on_consumer_error(
exc: Exception,
message: Optional[ConfluentKafkaMessageProto],
message: Optional[RawConfluentKafkaMessageProto],
logger: logging.Logger,
):
topic, partition, offset = None, None, None
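For context, a custom consumer error callback written against the renamed protocol is typed as in the sketch below. The import paths and the callback signature come from this diff; the suppression behaviour (returning `True`) is an assumption about how the boolean result is interpreted.

```python
import logging
from typing import Optional

from quixstreams.error_callbacks import ConsumerErrorCallback
from quixstreams.models.types import RawConfluentKafkaMessageProto


def on_consumer_error(
    exc: Exception,
    message: Optional[RawConfluentKafkaMessageProto],
    logger: logging.Logger,
) -> bool:
    # A "raw" message may still carry an error, so only metadata is touched here.
    if message is not None:
        logger.error(
            "Consumer error at %s[%s] offset %s: %s",
            message.topic(),
            message.partition(),
            message.offset(),
            exc,
        )
    return True  # assumption: True means "handled, keep consuming"


handler: ConsumerErrorCallback = on_consumer_error  # mypy accepts the assignment
```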
19 changes: 17 additions & 2 deletions quixstreams/kafka/consumer.py
@@ -1,7 +1,7 @@
import functools
import logging
import typing
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union, cast

from confluent_kafka import (
Consumer as ConfluentConsumer,
@@ -14,8 +14,13 @@
from confluent_kafka.admin import ClusterMetadata, GroupMetadata

from quixstreams.exceptions import KafkaPartitionError, PartitionAssignmentError
from quixstreams.models.types import (
RawConfluentKafkaMessageProto,
SuccessfulConfluentKafkaMessageProto,
)

from .configuration import ConnectionConfig
from .exceptions import KafkaConsumerException

__all__ = (
"BaseConsumer",
@@ -65,6 +70,14 @@ def wrapper(*args, **kwargs):
return wrapper


def raise_for_msg_error(
msg: RawConfluentKafkaMessageProto,
) -> SuccessfulConfluentKafkaMessageProto:
if msg.error():
raise KafkaConsumerException(error=msg.error())
return cast(SuccessfulConfluentKafkaMessageProto, msg)


class BaseConsumer:
def __init__(
self,
@@ -129,7 +142,9 @@ def __init__(
}
self._inner_consumer: Optional[ConfluentConsumer] = None

def poll(self, timeout: Optional[float] = None) -> Optional[Message]:
def poll(
self, timeout: Optional[float] = None
) -> Optional[RawConfluentKafkaMessageProto]:
"""
Consumes a single message, calls callbacks and returns events.

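`poll()` is now annotated as returning the raw protocol because a fetched message may carry a broker error; `raise_for_msg_error` is the narrowing step that turns it into a `SuccessfulConfluentKafkaMessageProto` for the type checker. A rough usage sketch, assuming an already-configured consumer instance:

```python
from quixstreams.kafka.consumer import BaseConsumer, raise_for_msg_error


def poll_one(consumer: BaseConsumer) -> None:
    msg = consumer.poll(timeout=1.0)  # Optional[RawConfluentKafkaMessageProto]
    if msg is None:
        return  # nothing arrived within the timeout

    # Raises KafkaConsumerException if msg.error() is set; otherwise the cast
    # lets mypy treat key()/value()/offset() as non-error payloads.
    ok_msg = raise_for_msg_error(msg)
    print(ok_msg.topic(), ok_msg.offset(), ok_msg.value())
```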
2 changes: 1 addition & 1 deletion quixstreams/models/rows.py
@@ -19,7 +19,7 @@ def __init__(
key: Optional[Any],
timestamp: int,
context: MessageContext,
headers: KafkaHeaders = None,
headers: KafkaHeaders,
):
self.value = value
self.key = key
34 changes: 21 additions & 13 deletions quixstreams/models/serializers/avro.py
@@ -148,6 +148,12 @@ def __init__(
)

super().__init__()

if schema is None and schema_registry_client_config is None:
raise TypeError(
"One of `schema` or `schema_registry_client_config` is required"
)

self._schema = parse_schema(schema) if schema else None
self._reader_schema = parse_schema(reader_schema) if reader_schema else None
self._return_record_name = return_record_name
@@ -174,17 +180,19 @@ def __call__(
return self._schema_registry_deserializer(value, ctx)
except (SchemaRegistryError, _SerializationError, EOFError) as exc:
raise SerializationError(str(exc)) from exc
elif self._schema is not None:
try:
return schemaless_reader( # type: ignore
BytesIO(value),
self._schema,
reader_schema=self._reader_schema,
return_record_name=self._return_record_name,
return_record_name_override=self._return_record_name_override,
return_named_type=self._return_named_type,
return_named_type_override=self._return_named_type_override,
handle_unicode_errors=self._handle_unicode_errors,
)
except EOFError as exc:
raise SerializationError(str(exc)) from exc

try:
return schemaless_reader(
BytesIO(value),
self._schema,
reader_schema=self._reader_schema,
return_record_name=self._return_record_name,
return_record_name_override=self._return_record_name_override,
return_named_type=self._return_named_type,
return_named_type_override=self._return_named_type_override,
handle_unicode_errors=self._handle_unicode_errors,
)
except EOFError as exc:
raise SerializationError(str(exc)) from exc
raise SerializationError("no schema found")
6 changes: 3 additions & 3 deletions quixstreams/models/serializers/base.py
@@ -9,7 +9,7 @@
)
from typing_extensions import Literal, TypeAlias

from ..types import HeadersMapping, KafkaHeaders
from ..types import Headers, HeadersMapping, KafkaHeaders

__all__ = (
"SerializationContext",
@@ -33,8 +33,8 @@ class SerializationContext(_SerializationContext):
def __init__(
self,
topic: str,
field: MessageField,
headers: KafkaHeaders = None,
field: str,
headers: Union[KafkaHeaders, Headers] = None,
) -> None:
self.topic = topic
self.field = field
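`MessageField` members are plain string constants, so the `field` parameter is now annotated as `str`, and `headers` accepts either the tuple-based `KafkaHeaders` form or a `Headers` mapping. A minimal construction sketch (topic name and header values are illustrative):

```python
from confluent_kafka.serialization import MessageField

from quixstreams.models.serializers.base import SerializationContext

ctx = SerializationContext(
    topic="sensor-readings",
    field=MessageField.VALUE,  # the constant is just the string "value"
    headers=[("content-type", b"application/json")],  # KafkaHeaders-style list
)
```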
10 changes: 6 additions & 4 deletions quixstreams/models/serializers/json.py
@@ -1,5 +1,5 @@
import json
from typing import Any, Callable, Iterable, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Union

from confluent_kafka.schema_registry import SchemaRegistryClient, SchemaRegistryError
from confluent_kafka.schema_registry.json_schema import (
@@ -10,7 +10,6 @@
)
from confluent_kafka.serialization import SerializationError as _SerializationError
from jsonschema import Draft202012Validator, ValidationError
from jsonschema.protocols import Validator

from quixstreams.utils.json import (
dumps as default_dumps,
@@ -26,6 +25,9 @@
SchemaRegistrySerializationConfig,
)

if TYPE_CHECKING:
from jsonschema.validators import _Validator

__all__ = ("JSONSerializer", "JSONDeserializer")


@@ -34,7 +36,7 @@ def __init__(
self,
dumps: Callable[[Any], Union[str, bytes]] = default_dumps,
schema: Optional[Mapping] = None,
validator: Optional[Validator] = None,
validator: Optional["_Validator"] = None,
schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
schema_registry_serialization_config: Optional[
SchemaRegistrySerializationConfig
@@ -121,7 +123,7 @@ def __init__(
self,
loads: Callable[[Union[bytes, bytearray]], Any] = default_loads,
schema: Optional[Mapping] = None,
validator: Optional[Validator] = None,
validator: Optional["_Validator"] = None,
schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
):
"""
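The public `Validator` protocol import is replaced by jsonschema's internal `_Validator` name, pulled in only under `TYPE_CHECKING`, so nothing extra is imported at runtime. In use, the `validator` argument still takes a validator instance, as in this sketch with an illustrative schema:

```python
from jsonschema import Draft202012Validator

from quixstreams.models.serializers.json import JSONDeserializer, JSONSerializer

# Illustrative JSON Schema; any draft supported by jsonschema works.
SCHEMA = {
    "type": "object",
    "properties": {"id": {"type": "string"}, "value": {"type": "number"}},
    "required": ["id", "value"],
}

serializer = JSONSerializer(validator=Draft202012Validator(SCHEMA))
deserializer = JSONDeserializer(validator=Draft202012Validator(SCHEMA))
```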
6 changes: 3 additions & 3 deletions quixstreams/models/serializers/protobuf.py
@@ -1,4 +1,4 @@
from typing import Dict, Iterable, Mapping, Optional, Union
from typing import Dict, Iterable, Mapping, Optional, Type, Union

from confluent_kafka.schema_registry import SchemaRegistryClient, SchemaRegistryError
from confluent_kafka.schema_registry.protobuf import (
@@ -24,7 +24,7 @@
class ProtobufSerializer(Serializer):
def __init__(
self,
msg_type: Message,
msg_type: Type[Message],
deterministic: bool = False,
ignore_unknown_fields: bool = False,
schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
@@ -110,7 +110,7 @@ def __call__(
class ProtobufDeserializer(Deserializer):
def __init__(
self,
msg_type: Message,
msg_type: Type[Message],
use_integers_for_enums: bool = False,
preserving_proto_field_name: bool = False,
to_dict: bool = True,
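The tightened `Type[Message]` annotation makes explicit that `msg_type` is the generated message class itself, not an instance. A sketch assuming a hypothetical protoc-generated module `sensor_pb2` with a `SensorReading` message:

```python
from sensor_pb2 import SensorReading  # hypothetical generated module

from quixstreams.models.serializers.protobuf import (
    ProtobufDeserializer,
    ProtobufSerializer,
)

# Pass the class (Type[Message]); passing an instance would now be a type error.
serializer = ProtobufSerializer(msg_type=SensorReading)
deserializer = ProtobufDeserializer(msg_type=SensorReading)
```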
28 changes: 3 additions & 25 deletions quixstreams/models/topics/admin.py
@@ -7,10 +7,7 @@
from confluent_kafka.admin import (
AdminClient,
ConfigResource,
KafkaException, # type: ignore
)
from confluent_kafka.admin import (
NewTopic as ConfluentTopic, # type: ignore
KafkaException,
)
from confluent_kafka.admin import (
TopicMetadata as ConfluentTopicMetadata,
@@ -26,25 +23,6 @@
__all__ = ("TopicAdmin",)


def convert_topic_list(topics: List[Topic]) -> List[ConfluentTopic]:
"""
Converts `Topic`s to `ConfluentTopic`s as required for Confluent's
`AdminClient.create_topic()`.

:param topics: list of `Topic`s
:return: list of confluent_kafka `ConfluentTopic`s
"""
return [
ConfluentTopic(
topic=topic.name,
num_partitions=topic.config.num_partitions,
replication_factor=topic.config.replication_factor,
config=topic.config.extra_config,
)
for topic in topics
]


def confluent_topic_config(topic: str) -> ConfigResource:
return ConfigResource(2, topic)

@@ -207,12 +185,12 @@ def create_topics(
for topic in topics_to_create:
logger.info(
f'Creating a new topic "{topic.name}" '
f'with config: "{topic.config.as_dict()}"'
f'with config: "{topic.config.as_dict() if topic.config is not None else {}}"'
)

self._finalize_create(
self.admin_client.create_topics(
convert_topic_list(topics_to_create),
[topic.as_newtopic() for topic in topics_to_create],
request_timeout=timeout,
),
finalize_timeout=finalize_timeout,
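`create_topics` no longer needs the module-level `convert_topic_list` helper because each `Topic` now carries the conversion itself. Judging from the deleted helper, `Topic.as_newtopic()` presumably builds the same `confluent_kafka` object, roughly as in this sketch:

```python
from confluent_kafka.admin import NewTopic


def as_newtopic_sketch(topic) -> NewTopic:
    # Rough reconstruction of what Topic.as_newtopic() is expected to return,
    # based on the removed convert_topic_list(); the real method may also
    # handle a missing topic.config, as the new logging line suggests.
    config = topic.config
    return NewTopic(
        topic=topic.name,
        num_partitions=config.num_partitions,
        replication_factor=config.replication_factor,
        config=config.extra_config,
    )
```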
12 changes: 11 additions & 1 deletion quixstreams/models/topics/manager.py
@@ -40,7 +40,7 @@ class TopicManager:
# Default topic params
default_num_partitions = 1
default_replication_factor = 1
default_extra_config = {}
default_extra_config: dict[str, str] = {}

# Max topic name length for the new topics
_max_topic_name_len = 255
@@ -211,6 +211,9 @@ def _get_source_topic_config(
topic_name
] or deepcopy(self._non_changelog_topics[topic_name].config)

if topic_config is None:
raise RuntimeError(f"No configuration can be found for topic {topic_name}")

# Copy only certain configuration values from original topic
if extras_imports:
topic_config.extra_config = {
@@ -475,10 +478,17 @@ def validate_all_topics(self, timeout: Optional[float] = None):

for source_name in self._non_changelog_topics.keys():
source_cfg = actual_configs[source_name]
if source_cfg is None:
raise TopicNotFoundError(f"Topic {source_name} not found on the broker")

# For any changelog topics, validate the amount of partitions and
# replication factor match with the source topic
for changelog in self.changelog_topics.get(source_name, {}).values():
changelog_cfg = actual_configs[changelog.name]
if changelog_cfg is None:
raise TopicNotFoundError(
f"Topic {changelog_cfg} not found on the broker"
)

if changelog_cfg.num_partitions != source_cfg.num_partitions:
raise TopicConfigurationMismatch(
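Two of the remaining manager changes are typing hygiene rather than behaviour: mypy cannot infer a type for a bare `{}` class attribute, and indexing `actual_configs` may yield `None`, so both now get explicit handling. A stripped-down illustration of the pattern (class and function names are made up):

```python
from typing import Optional


class TopicDefaults:
    # Without the annotation mypy reports: Need type annotation for "extra_config"
    extra_config: dict[str, str] = {}


def require_config(
    actual_configs: dict[str, Optional[dict[str, str]]], name: str
) -> dict[str, str]:
    cfg = actual_configs[name]
    if cfg is None:
        # Same shape as the new TopicNotFoundError / RuntimeError guards:
        # fail loudly instead of letting None reach attribute access later.
        raise LookupError(f"Topic {name} not found on the broker")
    return cfg  # mypy narrows the return type to dict[str, str]
```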