diff --git a/docs/changelog/next_release/202.feature.rst b/docs/changelog/next_release/202.feature.rst new file mode 100644 index 000000000..09b93de95 --- /dev/null +++ b/docs/changelog/next_release/202.feature.rst @@ -0,0 +1,12 @@ +Add support of ``Incremental Strategies`` for ``Kafka`` connection. This lets you resume reading data from a Kafka topic starting at the last committed offset from your previous run. + +.. code-block:: python + + reader = DBReader( + connection=Kafka(...), + source="topic_name", + hwm=AutoDetectHWM(name="some_hwm_name", expression="offset"), + ) + + with IncrementalStrategy(): + df = reader.run() diff --git a/docs/connection/db_connection/kafka/read.rst b/docs/connection/db_connection/kafka/read.rst index d502c453e..8b2917943 100644 --- a/docs/connection/db_connection/kafka/read.rst +++ b/docs/connection/db_connection/kafka/read.rst @@ -5,10 +5,6 @@ Reading from Kafka For reading data from Kafka, use :obj:`DBReader ` with specific options (see below). -.. warning:: - - Currently, Kafka does not support :ref:`strategy`. You can only read the **whole** topic. - .. note:: Unlike other connection classes, Kafka always return dataframe with fixed schema diff --git a/onetl/connection/db_connection/kafka/connection.py b/onetl/connection/db_connection/kafka/connection.py index 36aa0fce0..21c29e380 100644 --- a/onetl/connection/db_connection/kafka/connection.py +++ b/onetl/connection/db_connection/kafka/connection.py @@ -14,6 +14,7 @@ from __future__ import annotations +import json import logging from contextlib import closing from typing import TYPE_CHECKING, Any, List, Optional @@ -258,7 +259,7 @@ def check(self): return self @slot - def read_source_as_df( + def read_source_as_df( # noqa: WPS231 self, source: str, columns: list[str] | None = None, @@ -276,7 +277,30 @@ def read_source_as_df( result_options = {f"kafka.{key}": value for key, value in self._get_connection_properties().items()} result_options.update(options.dict(by_alias=True, exclude_none=True)) result_options["subscribe"] = source + + if window and window.expression == "offset": + # the 'including' flag in window values are relevant for batch strategies which are not + # supported by Kafka, therefore we always get offsets including border values + starting_offsets = dict(window.start_from.value) if window.start_from.value else {} + ending_offsets = dict(window.stop_at.value) if window.stop_at.value else {} + + # when the Kafka topic's number of partitions has increased during incremental processing, + # new partitions, which are present in ending_offsets but not in + # starting_offsets, are assigned a default offset (0 in this case). 
+ for partition in ending_offsets: + if partition not in starting_offsets: + starting_offsets[partition] = 0 + + if starting_offsets: + result_options["startingOffsets"] = json.dumps({source: starting_offsets}) + if ending_offsets: + result_options["endingOffsets"] = json.dumps({source: ending_offsets}) + df = self.spark.read.format("kafka").options(**result_options).load() + + if limit is not None: + df = df.limit(limit) + log.info("|%s| Dataframe is successfully created.", self.__class__.__name__) return df @@ -471,6 +495,9 @@ def get_min_max_values( self, source: str, window: Window, + hint: Any | None = None, + where: Any | None = None, + options: KafkaReadOptions | dict | None = None, ) -> tuple[dict[int, int], dict[int, int]]: log.info("|%s| Getting min and max offset values for topic %r ...", self.__class__.__name__, source) diff --git a/onetl/connection/db_connection/kafka/dialect.py b/onetl/connection/db_connection/kafka/dialect.py index 65de18f1a..2de018a54 100644 --- a/onetl/connection/db_connection/kafka/dialect.py +++ b/onetl/connection/db_connection/kafka/dialect.py @@ -27,7 +27,6 @@ NotSupportDFSchema, NotSupportHint, NotSupportWhere, - SupportNameAny, ) if TYPE_CHECKING: @@ -41,11 +40,18 @@ class KafkaDialect( # noqa: WPS215 NotSupportDFSchema, NotSupportHint, NotSupportWhere, - SupportNameAny, DBDialect, ): SUPPORTED_HWM_COLUMNS = {"offset"} + def validate_name(self, value: str) -> str: + if "*" in value or "," in value: + raise ValueError( + f"source/target={value} is not supported by {self.connection.__class__.__name__}. " + f"Provide a singular topic.", + ) + return value + def validate_hwm( self, hwm: HWM | None, diff --git a/onetl/db/db_reader/db_reader.py b/onetl/db/db_reader/db_reader.py index 719208479..b899faab8 100644 --- a/onetl/db/db_reader/db_reader.py +++ b/onetl/db/db_reader/db_reader.py @@ -660,7 +660,7 @@ def _get_hwm_field(self, hwm: HWM) -> StructField: log.info("|%s| Got Spark field: %s", self.__class__.__name__, result) return result - def _calculate_window_and_limit(self) -> tuple[Window | None, int | None]: + def _calculate_window_and_limit(self) -> tuple[Window | None, int | None]: # noqa: WPS231 if not self.hwm: # SnapshotStrategy - always select all the data from source return None, None diff --git a/onetl/strategy/batch_hwm_strategy.py b/onetl/strategy/batch_hwm_strategy.py index 711a4ccff..d4de260e3 100644 --- a/onetl/strategy/batch_hwm_strategy.py +++ b/onetl/strategy/batch_hwm_strategy.py @@ -142,6 +142,9 @@ def check_hwm_increased(self, next_value: Any) -> None: @property def next(self) -> Edge: if self.current.is_set(): + if not hasattr(self.current.value, "__add__"): + raise RuntimeError(f"HWM: {self.hwm!r} cannot be used with Batch strategies") + result = Edge(value=self.current.value + self.step) else: result = Edge(value=self.stop) diff --git a/onetl/strategy/incremental_strategy.py b/onetl/strategy/incremental_strategy.py index 04cc87903..8da323818 100644 --- a/onetl/strategy/incremental_strategy.py +++ b/onetl/strategy/incremental_strategy.py @@ -294,6 +294,40 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy): FROM public.mydata WHERE business_dt > CAST('2021-01-09' AS DATE); -- from HWM-offset (EXCLUDING first row) + Incremental run with :ref:`db-reader` and :ref:`kafka` connection + (by ``offset`` in topic - :etl-entities:`KeyValueHWM `): + + .. 
code:: python + + from onetl.connection import Kafka + from onetl.db import DBReader + from onetl.strategy import IncrementalStrategy + from onetl.hwm import AutoDetectHWM + + from pyspark.sql import SparkSession + + maven_packages = Kafka.get_packages() + spark = ( + SparkSession.builder.appName("spark-app-name") + .config("spark.jars.packages", ",".join(maven_packages)) + .getOrCreate() + ) + + kafka = Kafka( + addresses=["mybroker:9092", "anotherbroker:9092"], + cluster="my-cluster", + spark=spark, + ) + + reader = DBReader( + connection=kafka, + source="topic_name", + hwm=DBReader.AutoDetectHWM(name="some_hwm_name", expression="offset"), + ) + + with IncrementalStrategy(): + df = reader.run() + Incremental run with :ref:`file-downloader` and ``hwm=FileListHWM(...)``: .. code:: python diff --git a/tests/fixtures/processing/kafka.py b/tests/fixtures/processing/kafka.py index e80371ea9..bddb3490b 100644 --- a/tests/fixtures/processing/kafka.py +++ b/tests/fixtures/processing/kafka.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import os from typing import TYPE_CHECKING @@ -17,6 +18,17 @@ class KafkaProcessing(BaseProcessing): column_names: list[str] = ["id_int", "text_string", "hwm_int", "float_value"] + def __enter__(self): + return self + + def __exit__(self, _exc_type, _exc_value, _traceback): + return False + + @property + def schema(self) -> str: + # Kafka does not support schemas + return "" + def get_consumer(self): from confluent_kafka import Consumer @@ -111,11 +123,47 @@ def get_expected_df(self, topic: str, num_messages: int = 1, timeout: float = DE def insert_data(self, schema: str, table: str, values: list) -> None: pass + def change_topic_partitions(self, topic: str, num_partitions: int, timeout: float = DEFAULT_TIMEOUT): + from confluent_kafka.admin import NewPartitions + + admin_client = self.get_admin_client() + + if not self.topic_exists(topic): + self.create_topic(topic, num_partitions) + else: + new_partitions = [NewPartitions(topic, num_partitions)] + # change the number of partitions + fs = admin_client.create_partitions(new_partitions, request_timeout=timeout) + + for topic, f in fs.items(): + try: + f.result() + except Exception as e: + raise Exception(f"Failed to update number of partitions for topic '{topic}': {e}") # noqa: WPS454 + + def create_topic(self, topic: str, num_partitions: int, timeout: float = DEFAULT_TIMEOUT): + from confluent_kafka.admin import KafkaException, NewTopic + + admin_client = self.get_admin_client() + topic_config = NewTopic(topic, num_partitions=num_partitions, replication_factor=1) + fs = admin_client.create_topics([topic_config], request_timeout=timeout) + + for topic, f in fs.items(): + try: + f.result() + except Exception as e: + raise KafkaException(f"Error creating topic '{topic}': {e}") + def delete_topic(self, topics: list[str], timeout: float = DEFAULT_TIMEOUT): admin = self.get_admin_client() # https://github.com/confluentinc/confluent-kafka-python/issues/813 admin.delete_topics(topics, request_timeout=timeout) + def insert_pandas_df_into_topic(self, df: pandas.DataFrame, topic: str): + for _, row in df.iterrows(): + message = json.dumps(row.to_dict()) + self.send_message(topic, message.encode("utf-8")) + def topic_exists(self, topic: str, timeout: float = DEFAULT_TIMEOUT) -> bool: admin = self.get_admin_client() topic_metadata = admin.list_topics(timeout=timeout) diff --git 
a/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_kafka.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_kafka.py index aad831c01..634f7bbf1 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_kafka.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_kafka.py @@ -21,7 +21,7 @@ def test_strategy_kafka_with_batch_strategy_error(strategy, spark): processing = KafkaProcessing() - with strategy(step=10): + with strategy(step=10) as batches: reader = DBReader( connection=Kafka( addresses=[f"{processing.host}:{processing.port}"], @@ -31,5 +31,10 @@ def test_strategy_kafka_with_batch_strategy_error(strategy, spark): table="topic", hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="offset"), ) - with pytest.raises(RuntimeError): - reader.run() + # raises as at current version there is no way to distribute step size between kafka partitions + with pytest.raises( + RuntimeError, + match=r"HWM: .* cannot be used with Batch strategies", + ): + for _ in batches: + reader.run() diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_kafka.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_kafka.py new file mode 100644 index 000000000..b3df91241 --- /dev/null +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_kafka.py @@ -0,0 +1,301 @@ +import secrets + +import pytest +from etl_entities.hwm import KeyValueIntHWM +from etl_entities.hwm_store import HWMStoreStackManager + +from onetl.connection import Kafka +from onetl.db import DBReader +from onetl.strategy import IncrementalStrategy + +pytestmark = pytest.mark.kafka + + +@pytest.fixture(name="schema") +def dataframe_schema(): + from pyspark.sql.types import ( + FloatType, + LongType, + StringType, + StructField, + StructType, + ) + + return StructType( + [ + StructField("id_int", LongType(), nullable=True), + StructField("text_string", StringType(), nullable=True), + StructField("hwm_int", LongType(), nullable=True), + StructField("float_value", FloatType(), nullable=True), + ], + ) + + +@pytest.mark.parametrize( + "num_partitions", + [ + None, # default number of partitions is 1 + 5, + 10, + ], +) +def test_kafka_strategy_incremental( + spark, + processing, + schema, + num_partitions, +): + from pyspark.sql.functions import count as spark_count + + hwm_type = KeyValueIntHWM + topic = secrets.token_hex(6) + hwm_name = secrets.token_hex(5) + store = HWMStoreStackManager.get_current() + + kafka = Kafka( + addresses=[f"{processing.host}:{processing.port}"], + cluster="cluster", + spark=spark, + ) + + # change the number of partitions for the Kafka topic to test work for different partitioning cases + if num_partitions is not None: + processing.change_topic_partitions(topic, num_partitions) + + reader = DBReader( + connection=kafka, + source=topic, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression="offset"), + ) + + # there are 2 spans with a gap between + + # 0..100 + first_span_begin = 0 + first_span_end = 100 + + # 110..210 + second_span_begin = 110 + second_span_end = 210 + + first_span = 
processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + # insert first span + processing.insert_pandas_df_into_topic(first_span, topic) + + # hwm is not in the store + assert store.get_hwm(hwm_name) is None + + # incremental run + with IncrementalStrategy(): + first_df = reader.run() + + hwm = store.get_hwm(hwm_name) + assert hwm is not None + assert isinstance(hwm, hwm_type) + + # check that HWM distribution of messages in partitions matches the distribution in sparkDF + partition_counts = first_df.groupBy("partition").agg(spark_count("*").alias("count")) + partition_count_dict_first_df = {row["partition"]: row["count"] for row in partition_counts.collect()} + assert hwm.value == partition_count_dict_first_df + + # all the data has been read + deserialized_first_df = processing.json_deserialize(first_df, df_schema=schema) + processing.assert_equal_df(df=deserialized_first_df, other_frame=first_span, order_by="id_int") + + # insert second span + processing.insert_pandas_df_into_topic(second_span, topic) + + with IncrementalStrategy(): + second_df = reader.run() + + hwm = store.get_hwm(hwm_name) + + # check that HWM distribution of messages in partitions matches the distribution in sparkDF + combined_df = first_df.union(second_df) + partition_counts_combined = combined_df.groupBy("partition").agg(spark_count("*").alias("count")) + partition_count_dict_combined = {row["partition"]: row["count"] for row in partition_counts_combined.collect()} + assert hwm.value == partition_count_dict_combined + + deserialized_second_df = processing.json_deserialize(second_df, df_schema=schema) + processing.assert_equal_df(df=deserialized_second_df, other_frame=second_span, order_by="id_int") + + +@pytest.mark.parametrize( + "num_partitions", + [ + None, # default number of partitions is 1 + 5, + 10, + ], +) +def test_kafka_strategy_incremental_nothing_to_read(spark, processing, schema, num_partitions): + from pyspark.sql.functions import count as spark_count + + topic = secrets.token_hex(6) + hwm_name = secrets.token_hex(5) + store = HWMStoreStackManager.get_current() + + kafka = Kafka( + addresses=[f"{processing.host}:{processing.port}"], + cluster="cluster", + spark=spark, + ) + + # change the number of partitions for the Kafka topic to test work for different partitioning cases + if num_partitions is not None: + processing.change_topic_partitions(topic, num_partitions) + + reader = DBReader( + connection=kafka, + source=topic, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression="offset"), + ) + + # 0..50 + first_span_begin = 0 + first_span_end = 50 + # 60..110 + second_span_begin = 60 + second_span_end = 110 + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + # no data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert all(value == 0 for value in hwm.value.values()) + + # insert first span + processing.insert_pandas_df_into_topic(first_span, topic) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + hwm = store.get_hwm(name=hwm_name) + assert all(value == 0 for value in hwm.value.values()) + + # set hwm value to 50 + with IncrementalStrategy(): + first_df = reader.run() + + hwm = store.get_hwm(name=hwm_name) + # check 
that HWM distribution of messages in partitions matches the distribution in sparkDF + partition_counts = first_df.groupBy("partition").agg(spark_count("*").alias("count")) + partition_count_dict_first_df = {row["partition"]: row["count"] for row in partition_counts.collect()} + assert hwm.value == partition_count_dict_first_df + + deserialized_df = processing.json_deserialize(first_df, df_schema=schema) + processing.assert_equal_df(df=deserialized_df, other_frame=first_span, order_by="id_int") + + # no new data yet, nothing to read + with IncrementalStrategy(): + df = reader.run() + + assert not df.count() + # HWM value is unchanged + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == partition_count_dict_first_df + + # insert second span + processing.insert_pandas_df_into_topic(second_span, topic) + + # .run() is not called - dataframe still empty - HWM not updated + assert not df.count() + # HWM value is unchanged + hwm = store.get_hwm(name=hwm_name) + assert hwm.value == partition_count_dict_first_df + + # read data + with IncrementalStrategy(): + df = reader.run() + + hwm = store.get_hwm(name=hwm_name) + # check that HWM distribution of messages in partitions matches the distribution in sparkDF + combined_df = df.union(first_df) + partition_counts_combined = combined_df.groupBy("partition").agg(spark_count("*").alias("count")) + partition_count_dict_combined = {row["partition"]: row["count"] for row in partition_counts_combined.collect()} + assert hwm.value == partition_count_dict_combined + + deserialized_df = processing.json_deserialize(df, df_schema=schema) + processing.assert_equal_df(df=deserialized_df, other_frame=second_span, order_by="id_int") + + +@pytest.mark.parametrize( + "initial_partitions, additional_partitions", + [ + (3, 2), # starting with 3 partitions, adding 2 more + (5, 1), # starting with 5 partitions, adding 1 more + ], +) +def test_kafka_strategy_incremental_with_new_partition( + spark, + processing, + schema, + initial_partitions, + additional_partitions, +): + from pyspark.sql.functions import count as spark_count + + topic = secrets.token_hex(6) + hwm_name = secrets.token_hex(5) + store = HWMStoreStackManager.get_current() + + kafka = Kafka( + addresses=[f"{processing.host}:{processing.port}"], + cluster="cluster", + spark=spark, + ) + + reader = DBReader( + connection=kafka, + source=topic, + hwm=DBReader.AutoDetectHWM(name=hwm_name, expression="offset"), + ) + + # Initial setup with `initial_partitions` partitions + processing.change_topic_partitions(topic, initial_partitions) + + # 0..50 + first_span_begin = 0 + first_span_end = 100 + + # 60..110 + second_span_begin = 60 + second_span_end = 110 + + first_span = processing.create_pandas_df(min_id=first_span_begin, max_id=first_span_end) + second_span = processing.create_pandas_df(min_id=second_span_begin, max_id=second_span_end) + + processing.insert_pandas_df_into_topic(first_span, topic) + with IncrementalStrategy(): + first_df = reader.run() + + # it is crucial to save dataframe after reading as if number of partitions is altered before executing any subsequent operations, Spark fails to run them due to + # Caused by: java.lang.AssertionError: assertion failed: If startingOffsets contains specific offsets, you must specify all TopicPartitions. + # Use -1 for latest, -2 for earliest. 
+ # Specified: Set(topic1, topic2) Assigned: Set(topic1, topic2, additional_topic3, additional_topic4) + first_df.cache() + + hwm = store.get_hwm(name=hwm_name) + first_run_hwm_keys_num = len(hwm.value.keys()) + + processing.change_topic_partitions(topic, initial_partitions + additional_partitions) + processing.insert_pandas_df_into_topic(second_span, topic) + + with IncrementalStrategy(): + second_df = reader.run() + + hwm = store.get_hwm(name=hwm_name) + second_run_hwm_keys_num = len(hwm.value) + assert first_run_hwm_keys_num + additional_partitions == second_run_hwm_keys_num + + # check that HWM distribution of messages in partitions matches the distribution in sparkDF + combined_df = second_df.union(first_df) + partition_counts_combined = combined_df.groupBy("partition").agg(spark_count("*").alias("count")) + partition_count_dict_combined = {row["partition"]: row["count"] for row in partition_counts_combined.collect()} + assert hwm.value == partition_count_dict_combined diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py index acb6ff807..3ff5c9502 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py @@ -102,3 +102,28 @@ def test_kafka_reader_invalid_hwm_column(spark_mock, hwm_expression): table="table", hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_expression), ) + + +@pytest.mark.parametrize( + "topic, error_message", + [ + ("*", r"source/target=\* is not supported by Kafka. Provide a singular topic."), + ("topic1, topic2", "source/target=topic1, topic2 is not supported by Kafka. Provide a singular topic."), + ], +) +def test_kafka_reader_invalid_source(spark_mock, topic, error_message): + kafka = Kafka( + addresses=["localhost:9092"], + cluster="my_cluster", + spark=spark_mock, + ) + + with pytest.raises( + ValueError, + match=error_message, + ): + DBReader( + connection=kafka, + table=topic, + hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression="offset"), + )
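
For illustration only (not part of the patch): the new ``read_source_as_df`` logic above boils down to translating the per-partition offsets of a ``KeyValueIntHWM`` window into Spark's ``startingOffsets``/``endingOffsets`` JSON options. A minimal sketch, with a hypothetical topic name and offset values:

.. code-block:: python

    import json

    # hypothetical values: per-partition offsets taken from the previous run's HWM (start)
    # and from get_min_max_values (end); partition 2 appeared after the previous run
    source = "topic_name"
    starting_offsets = {0: 100, 1: 87}
    ending_offsets = {0: 150, 1: 120, 2: 35}

    # partitions present only in ending_offsets are read from the very beginning
    for partition in ending_offsets:
        if partition not in starting_offsets:
            starting_offsets[partition] = 0

    options = {
        "subscribe": source,
        # json.dumps converts int partition keys to strings, matching the format
        # Spark expects, e.g. {"topic_name": {"0": 100, "1": 87, "2": 0}}
        "startingOffsets": json.dumps({source: starting_offsets}),
        "endingOffsets": json.dumps({source: ending_offsets}),
    }

    # spark.read.format("kafka").options(**options).load() then reads each partition
    # from its starting offset (inclusive) up to its ending offset (exclusive,
    # per the Spark Kafka source semantics)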