From b4cb653428340f0c653a5ca6122b884f47c7594a Mon Sep 17 00:00:00 2001 From: maxim-lixakov Date: Wed, 19 Jul 2023 17:33:07 +0300 Subject: [PATCH 1/4] [DOP-7542] - add validation for hwm_column and hwm_expression in dialects' classes --- onetl/base/base_db_connection.py | 19 ++++- .../connection/db_connection/kafka/dialect.py | 51 +++++++++++++- onetl/db/db_reader/db_reader.py | 4 ++ tests/fixtures/spark_mock.py | 6 +- .../test_kafka_reader_unit.py | 69 +++++++++++++++++++ 5 files changed, 144 insertions(+), 5 deletions(-) diff --git a/onetl/base/base_db_connection.py b/onetl/base/base_db_connection.py index 12c624576..5b53cc8df 100644 --- a/onetl/base/base_db_connection.py +++ b/onetl/base/base_db_connection.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Callable -from etl_entities import Table +from etl_entities import Column, Table from onetl.base.base_connection import BaseConnection from onetl.hwm import Statement @@ -63,6 +63,23 @@ def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | Non If value is invalid """ + @classmethod + @abstractmethod + def validate_hwm_column( + cls, + connection: BaseDBConnection, + hwm_column: str | tuple[str, str] | Column | None, + ) -> str | tuple[str, str] | Column | None: + """Check if ``hwm_column`` value is valid. + + Raises + ------ + TypeError + If value type is invalid + ValueError + If value is invalid + """ + @classmethod @abstractmethod def validate_df_schema(cls, connection: BaseDBConnection, df_schema: StructType | None) -> StructType | None: diff --git a/onetl/connection/db_connection/kafka/dialect.py b/onetl/connection/db_connection/kafka/dialect.py index 4464381e4..91ac6a4e1 100644 --- a/onetl/connection/db_connection/kafka/dialect.py +++ b/onetl/connection/db_connection/kafka/dialect.py @@ -2,11 +2,14 @@ import logging -from onetl.connection.db_connection.db_connection import DBConnection +from etl_entities import Column + +from onetl.connection.db_connection.db_connection import BaseDBConnection, DBConnection from onetl.connection.db_connection.dialect_mixins import ( SupportColumnsNone, SupportDfSchemaNone, SupportHintNone, + SupportHWMExpressionNone, SupportTableWithoutDBSchema, SupportWhereNone, ) @@ -20,6 +23,50 @@ class KafkaDialect( # noqa: WPS215 SupportHintNone, SupportWhereNone, SupportTableWithoutDBSchema, + SupportHWMExpressionNone, DBConnection.Dialect, ): - pass + valid_hwm_columns = {"offset", "timestamp"} + + @classmethod + def validate_hwm_column( + cls, + connection: BaseDBConnection, + hwm_column: str | tuple[str, str] | Column | None, + ) -> str | tuple[str, str] | Column | None: + if isinstance(hwm_column, str): + cls.validate_single_column(connection, hwm_column) + elif isinstance(hwm_column, tuple): + cls.validate_tuple_columns(connection, hwm_column) + elif isinstance(hwm_column, Column): + cls.validate_column_class(connection, hwm_column) + + return hwm_column + + @classmethod + def validate_single_column(cls, connection: BaseDBConnection, column: str) -> None: + cls.validate_column(connection, column) + + @classmethod + def validate_tuple_columns(cls, connection: BaseDBConnection, columns: tuple[str, str]) -> None: + for column in columns: + cls.validate_column(connection, column) + + @classmethod + def validate_column_class(cls, connection: BaseDBConnection, column: Column) -> None: + cls.validate_column(connection, column.name) + + @classmethod + def validate_column(cls, connection: BaseDBConnection, column: str) -> None: + if column not in cls.valid_hwm_columns: + raise ValueError(f"{column} is not a valid hwm column. Valid options are: {cls.valid_hwm_columns}") + if column == "timestamp": + cls.check_spark_version(connection) + + @staticmethod + def check_spark_version(connection: BaseDBConnection) -> None: + spark_version = connection.spark.version # type: ignore[attr-defined] + major_version = int(spark_version.split(".")[0]) + + if major_version < 3: + raise ValueError(f"Spark version must be 3.x for the timestamp column. Current version is: {spark_version}") diff --git a/onetl/db/db_reader/db_reader.py b/onetl/db/db_reader/db_reader.py index 6c9f793a4..f03eb68b8 100644 --- a/onetl/db/db_reader/db_reader.py +++ b/onetl/db/db_reader/db_reader.py @@ -398,6 +398,10 @@ def validate_hwm_column(cls, values: dict) -> dict: values["hwm_column"] = Column(name=hwm_column) # type: ignore values["hwm_expression"] = hwm_expression + connection: BaseDBConnection = values["connection"] + dialect = connection.Dialect + dialect.validate_hwm_column(connection, hwm_column) + return values @root_validator(pre=True) # noqa: WPS231 diff --git a/tests/fixtures/spark_mock.py b/tests/fixtures/spark_mock.py index d44633390..d25a66010 100644 --- a/tests/fixtures/spark_mock.py +++ b/tests/fixtures/spark_mock.py @@ -6,13 +6,15 @@ @pytest.fixture( scope="function", params=[ - pytest.param("mock", marks=[pytest.mark.db_connection, pytest.mark.connection]), + pytest.param("2.3.0", id="Spark 2.3.0"), + pytest.param("3.3.0", id="Spark 3.3.0"), ], ) -def spark_mock(): +def spark_mock(request): from pyspark.sql import SparkSession spark = Mock(spec=SparkSession) + spark.version = request.param # sets the version according to the params spark.sparkContext = Mock() spark.sparkContext.appName = "abc" return spark diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py index e102897e9..68257b405 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py @@ -1,4 +1,5 @@ import pytest +from etl_entities import Column from onetl.connection import Kafka from onetl.db import DBReader @@ -86,3 +87,71 @@ def test_kafka_reader_unsupported_parameters(spark_mock, df_schema): table="table", df_schema=df_schema, ) + + +def test_kafka_reader_valid_hwm_column(spark_mock): + kafka = Kafka( + addresses=["localhost:9092"], + cluster="my_cluster", + spark=spark_mock, + ) + + try: + DBReader( + connection=kafka, + table="table", + hwm_column="offset", + ) + + DBReader( + connection=kafka, + table="table", + hwm_column=Column(name="offset"), + ) + except ValueError: + pytest.fail("ValueError for hwm_column raised unexpectedly!") + + if spark_mock.version.startswith("3."): + try: + DBReader( + connection=kafka, + table="table", + hwm_column="timestamp", + ) + except ValueError: + pytest.fail("ValueError for hwm_column raised unexpectedly!") + else: + with pytest.raises(ValueError, match="Spark version must be 3.x"): + DBReader( + connection=kafka, + table="table", + hwm_column="timestamp", + ) + + +def test_kafka_reader_invalid_hwm_column(spark_mock): + kafka = Kafka( + addresses=["localhost:9092"], + cluster="my_cluster", + spark=spark_mock, + ) + + with pytest.raises( + ValueError, + match="is not a valid hwm column", + ): + DBReader( + connection=kafka, + table="table", + hwm_column="unknown", + ) + + with pytest.raises( + ValueError, + match="is not a valid hwm column", + ): + DBReader( + connection=kafka, + table="table", + hwm_column=("some", "thing"), + ) From 321d9647a74dfdbae39e9126e4b9591c7796eaec Mon Sep 17 00:00:00 2001 From: Maxim Liksakov <67663774+maxim-lixakov@users.noreply.github.com> Date: Thu, 20 Jul 2023 11:31:08 +0300 Subject: [PATCH 2/4] Update onetl/base/base_db_connection.py Co-authored-by: Maxim Martynov --- onetl/base/base_db_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onetl/base/base_db_connection.py b/onetl/base/base_db_connection.py index 5b53cc8df..2dae3387c 100644 --- a/onetl/base/base_db_connection.py +++ b/onetl/base/base_db_connection.py @@ -68,7 +68,7 @@ def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | Non def validate_hwm_column( cls, connection: BaseDBConnection, - hwm_column: str | tuple[str, str] | Column | None, + hwm_column: str | None, ) -> str | tuple[str, str] | Column | None: """Check if ``hwm_column`` value is valid. From 7ed73f0d9ad9fad459e14c8f9272d6efa882db00 Mon Sep 17 00:00:00 2001 From: maxim-lixakov Date: Thu, 20 Jul 2023 15:23:37 +0300 Subject: [PATCH 3/4] [DOP-7542] - corrected tests and validation for Kafka.Dialect.validate_hwm_column --- onetl/base/base_db_connection.py | 4 +- .../connection/db_connection/kafka/dialect.py | 37 ++++-------- onetl/db/db_reader/db_reader.py | 5 +- tests/fixtures/spark_mock.py | 8 +-- .../test_kafka_reader_unit.py | 58 +++++++++---------- 5 files changed, 44 insertions(+), 68 deletions(-) diff --git a/onetl/base/base_db_connection.py b/onetl/base/base_db_connection.py index 2dae3387c..64594b731 100644 --- a/onetl/base/base_db_connection.py +++ b/onetl/base/base_db_connection.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Callable -from etl_entities import Column, Table +from etl_entities import Table from onetl.base.base_connection import BaseConnection from onetl.hwm import Statement @@ -69,7 +69,7 @@ def validate_hwm_column( cls, connection: BaseDBConnection, hwm_column: str | None, - ) -> str | tuple[str, str] | Column | None: + ) -> str | None: """Check if ``hwm_column`` value is valid. Raises diff --git a/onetl/connection/db_connection/kafka/dialect.py b/onetl/connection/db_connection/kafka/dialect.py index 91ac6a4e1..963edce65 100644 --- a/onetl/connection/db_connection/kafka/dialect.py +++ b/onetl/connection/db_connection/kafka/dialect.py @@ -2,8 +2,6 @@ import logging -from etl_entities import Column - from onetl.connection.db_connection.db_connection import BaseDBConnection, DBConnection from onetl.connection.db_connection.dialect_mixins import ( SupportColumnsNone, @@ -32,39 +30,28 @@ class KafkaDialect( # noqa: WPS215 def validate_hwm_column( cls, connection: BaseDBConnection, - hwm_column: str | tuple[str, str] | Column | None, - ) -> str | tuple[str, str] | Column | None: - if isinstance(hwm_column, str): - cls.validate_single_column(connection, hwm_column) - elif isinstance(hwm_column, tuple): - cls.validate_tuple_columns(connection, hwm_column) - elif isinstance(hwm_column, Column): - cls.validate_column_class(connection, hwm_column) - - return hwm_column - - @classmethod - def validate_single_column(cls, connection: BaseDBConnection, column: str) -> None: - cls.validate_column(connection, column) + hwm_column: str | None, + ) -> str | None: + if not isinstance(hwm_column, str): + raise ValueError( + f"{connection.__class__.__name__} requires 'hwm_column' parameter type to be 'str', " + f"got {type(hwm_column)}", + ) - @classmethod - def validate_tuple_columns(cls, connection: BaseDBConnection, columns: tuple[str, str]) -> None: - for column in columns: - cls.validate_column(connection, column) + cls.validate_column(connection, hwm_column) - @classmethod - def validate_column_class(cls, connection: BaseDBConnection, column: Column) -> None: - cls.validate_column(connection, column.name) + return hwm_column @classmethod def validate_column(cls, connection: BaseDBConnection, column: str) -> None: if column not in cls.valid_hwm_columns: raise ValueError(f"{column} is not a valid hwm column. Valid options are: {cls.valid_hwm_columns}") if column == "timestamp": - cls.check_spark_version(connection) + # Spark version less 3.x does not support reading from Kafka with the timestamp parameter + cls._check_spark_version(connection) @staticmethod - def check_spark_version(connection: BaseDBConnection) -> None: + def _check_spark_version(connection: BaseDBConnection) -> None: spark_version = connection.spark.version # type: ignore[attr-defined] major_version = int(spark_version.split(".")[0]) diff --git a/onetl/db/db_reader/db_reader.py b/onetl/db/db_reader/db_reader.py index f03eb68b8..91dff4e2d 100644 --- a/onetl/db/db_reader/db_reader.py +++ b/onetl/db/db_reader/db_reader.py @@ -373,9 +373,9 @@ def validate_hwm_column(cls, values: dict) -> dict: hwm_column: str | tuple[str, str] | Column | None = values.get("hwm_column") df_schema: StructType | None = values.get("df_schema") hwm_expression: str | None = values.get("hwm_expression") + connection: BaseDBConnection = values["connection"] if hwm_column is None or isinstance(hwm_column, Column): - # nothing to validate return values if not hwm_expression and not isinstance(hwm_column, str): @@ -398,9 +398,8 @@ def validate_hwm_column(cls, values: dict) -> dict: values["hwm_column"] = Column(name=hwm_column) # type: ignore values["hwm_expression"] = hwm_expression - connection: BaseDBConnection = values["connection"] dialect = connection.Dialect - dialect.validate_hwm_column(connection, hwm_column) + dialect.validate_hwm_column(connection, hwm_column) # type: ignore return values diff --git a/tests/fixtures/spark_mock.py b/tests/fixtures/spark_mock.py index d25a66010..ebacefb15 100644 --- a/tests/fixtures/spark_mock.py +++ b/tests/fixtures/spark_mock.py @@ -5,16 +5,12 @@ @pytest.fixture( scope="function", - params=[ - pytest.param("2.3.0", id="Spark 2.3.0"), - pytest.param("3.3.0", id="Spark 3.3.0"), - ], + params=[pytest.param("mock", marks=[pytest.mark.db_connection, pytest.mark.connection])], ) -def spark_mock(request): +def spark_mock(): from pyspark.sql import SparkSession spark = Mock(spec=SparkSession) - spark.version = request.param # sets the version according to the params spark.sparkContext = Mock() spark.sparkContext.appName = "abc" return spark diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py index 68257b405..372e2a24e 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from etl_entities import Column @@ -96,31 +98,32 @@ def test_kafka_reader_valid_hwm_column(spark_mock): spark=spark_mock, ) - try: - DBReader( - connection=kafka, - table="table", - hwm_column="offset", - ) + DBReader( + connection=kafka, + table="table", + hwm_column="offset", + ) + + DBReader( + connection=kafka, + table="table", + hwm_column=Column(name="offset"), + ) + +def test_kafka_reader_hwm_column_by_version(spark_mock): + kafka = Kafka( + addresses=["localhost:9092"], + cluster="my_cluster", + spark=spark_mock, + ) + with patch.object(spark_mock, "version", new="3.3.0"): DBReader( connection=kafka, table="table", - hwm_column=Column(name="offset"), + hwm_column="timestamp", ) - except ValueError: - pytest.fail("ValueError for hwm_column raised unexpectedly!") - - if spark_mock.version.startswith("3."): - try: - DBReader( - connection=kafka, - table="table", - hwm_column="timestamp", - ) - except ValueError: - pytest.fail("ValueError for hwm_column raised unexpectedly!") - else: + with patch.object(spark_mock, "version", new="2.3.0"): with pytest.raises(ValueError, match="Spark version must be 3.x"): DBReader( connection=kafka, @@ -129,7 +132,8 @@ def test_kafka_reader_valid_hwm_column(spark_mock): ) -def test_kafka_reader_invalid_hwm_column(spark_mock): +@pytest.mark.parametrize("hwm_column", ["unknown", '("some", "thing")']) +def test_kafka_reader_invalid_hwm_column(spark_mock, hwm_column): kafka = Kafka( addresses=["localhost:9092"], cluster="my_cluster", @@ -143,15 +147,5 @@ def test_kafka_reader_invalid_hwm_column(spark_mock): DBReader( connection=kafka, table="table", - hwm_column="unknown", - ) - - with pytest.raises( - ValueError, - match="is not a valid hwm column", - ): - DBReader( - connection=kafka, - table="table", - hwm_column=("some", "thing"), + hwm_column=hwm_column, ) From 869b0bd447124222c6353ac9fd991b2ab645c26a Mon Sep 17 00:00:00 2001 From: maxim-lixakov Date: Thu, 20 Jul 2023 16:04:37 +0300 Subject: [PATCH 4/4] [DOP-7542] - add SupportHWMColumnStr mixin --- .../db_connection/dialect_mixins/__init__.py | 3 +++ .../dialect_mixins/support_hwm_column_str.py | 19 +++++++++++++++++++ onetl/connection/db_connection/greenplum.py | 2 ++ onetl/connection/db_connection/hive.py | 2 ++ .../db_connection/jdbc_connection.py | 2 ++ onetl/connection/db_connection/mongodb.py | 4 +++- onetl/connection/db_connection/postgres.py | 2 ++ 7 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 onetl/connection/db_connection/dialect_mixins/support_hwm_column_str.py diff --git a/onetl/connection/db_connection/dialect_mixins/__init__.py b/onetl/connection/db_connection/dialect_mixins/__init__.py index a1822af14..1eb0d7f17 100644 --- a/onetl/connection/db_connection/dialect_mixins/__init__.py +++ b/onetl/connection/db_connection/dialect_mixins/__init__.py @@ -16,6 +16,9 @@ from onetl.connection.db_connection.dialect_mixins.support_hint_str import ( SupportHintStr, ) +from onetl.connection.db_connection.dialect_mixins.support_hwm_column_str import ( + SupportHWMColumnStr, +) from onetl.connection.db_connection.dialect_mixins.support_hwm_expression_none import ( SupportHWMExpressionNone, ) diff --git a/onetl/connection/db_connection/dialect_mixins/support_hwm_column_str.py b/onetl/connection/db_connection/dialect_mixins/support_hwm_column_str.py new file mode 100644 index 000000000..a27e6af76 --- /dev/null +++ b/onetl/connection/db_connection/dialect_mixins/support_hwm_column_str.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from onetl.base import BaseDBConnection + + +class SupportHWMColumnStr: + @classmethod + def validate_hwm_column( + cls, + connection: BaseDBConnection, + hwm_column: str | None, + ) -> str | None: + if not isinstance(hwm_column, str): + raise ValueError( + f"{connection.__class__.__name__} requires 'hwm_column' parameter type to be 'str', " + f"got {type(hwm_column)}", + ) + + return hwm_column diff --git a/onetl/connection/db_connection/greenplum.py b/onetl/connection/db_connection/greenplum.py index 61ea887b6..3a28c6dd2 100644 --- a/onetl/connection/db_connection/greenplum.py +++ b/onetl/connection/db_connection/greenplum.py @@ -35,6 +35,7 @@ SupportColumnsList, SupportDfSchemaNone, SupportHintNone, + SupportHWMColumnStr, SupportHWMExpressionStr, SupportWhereStr, ) @@ -447,6 +448,7 @@ class Dialect( # noqa: WPS215 SupportWhereStr, SupportHintNone, SupportHWMExpressionStr, + SupportHWMColumnStr, DBConnection.Dialect, ): @classmethod diff --git a/onetl/connection/db_connection/hive.py b/onetl/connection/db_connection/hive.py index 94397f2a4..0999fae2b 100644 --- a/onetl/connection/db_connection/hive.py +++ b/onetl/connection/db_connection/hive.py @@ -29,6 +29,7 @@ SupportColumnsList, SupportDfSchemaNone, SupportHintStr, + SupportHWMColumnStr, SupportHWMExpressionStr, SupportWhereStr, ) @@ -504,6 +505,7 @@ class Dialect( # noqa: WPS215 SupportWhereStr, SupportHintStr, SupportHWMExpressionStr, + SupportHWMColumnStr, DBConnection.Dialect, ): pass diff --git a/onetl/connection/db_connection/jdbc_connection.py b/onetl/connection/db_connection/jdbc_connection.py index f6438296b..19af6a9a3 100644 --- a/onetl/connection/db_connection/jdbc_connection.py +++ b/onetl/connection/db_connection/jdbc_connection.py @@ -29,6 +29,7 @@ SupportColumnsList, SupportDfSchemaNone, SupportHintStr, + SupportHWMColumnStr, SupportHWMExpressionStr, SupportWhereStr, ) @@ -148,6 +149,7 @@ class Dialect( # noqa: WPS215 SupportWhereStr, SupportHintStr, SupportHWMExpressionStr, + SupportHWMColumnStr, DBConnection.Dialect, ): pass diff --git a/onetl/connection/db_connection/mongodb.py b/onetl/connection/db_connection/mongodb.py index a1e4f1763..15f34ca0d 100644 --- a/onetl/connection/db_connection/mongodb.py +++ b/onetl/connection/db_connection/mongodb.py @@ -30,6 +30,7 @@ from onetl.connection.db_connection.dialect_mixins import ( SupportColumnsNone, SupportDfSchemaStruct, + SupportHWMColumnStr, SupportHWMExpressionNone, ) from onetl.connection.db_connection.dialect_mixins.support_table_without_dbschema import ( @@ -401,11 +402,12 @@ class Config: known_options = KNOWN_WRITE_OPTIONS extra = "allow" - class Dialect( + class Dialect( # noqa: WPS215 SupportTableWithoutDBSchema, SupportHWMExpressionNone, SupportColumnsNone, SupportDfSchemaStruct, + SupportHWMColumnStr, DBConnection.Dialect, ): _compare_statements: ClassVar[Dict[Callable, str]] = { diff --git a/onetl/connection/db_connection/postgres.py b/onetl/connection/db_connection/postgres.py index 76606e5e6..be068475b 100644 --- a/onetl/connection/db_connection/postgres.py +++ b/onetl/connection/db_connection/postgres.py @@ -22,6 +22,7 @@ SupportColumnsList, SupportDfSchemaNone, SupportHintNone, + SupportHWMColumnStr, SupportHWMExpressionStr, SupportWhereStr, ) @@ -134,6 +135,7 @@ class Dialect( # noqa: WPS215 SupportDfSchemaNone, SupportWhereStr, SupportHWMExpressionStr, + SupportHWMColumnStr, SupportHintNone, DBConnection.Dialect, ):