Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOP-7542] - add validation for hwm_column parameter #78

Merged
merged 6 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions onetl/base/base_db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ def validate_columns(cls, connection: BaseDBConnection, columns: list[str] | Non
If value is invalid
"""

@classmethod
@abstractmethod
def validate_hwm_column(
cls,
connection: BaseDBConnection,
hwm_column: str | None,
) -> str | None:
"""Check if ``hwm_column`` value is valid.

Raises
------
TypeError
If value type is invalid
ValueError
If value is invalid
"""

@classmethod
@abstractmethod
def validate_df_schema(cls, connection: BaseDBConnection, df_schema: StructType | None) -> StructType | None:
Expand Down
3 changes: 3 additions & 0 deletions onetl/connection/db_connection/dialect_mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from onetl.connection.db_connection.dialect_mixins.support_hint_str import (
SupportHintStr,
)
from onetl.connection.db_connection.dialect_mixins.support_hwm_column_str import (
SupportHWMColumnStr,
)
from onetl.connection.db_connection.dialect_mixins.support_hwm_expression_none import (
SupportHWMExpressionNone,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from onetl.base import BaseDBConnection


class SupportHWMColumnStr:
@classmethod
def validate_hwm_column(
cls,
connection: BaseDBConnection,
hwm_column: str | None,
) -> str | None:
if not isinstance(hwm_column, str):
raise ValueError(

Check warning on line 14 in onetl/connection/db_connection/dialect_mixins/support_hwm_column_str.py

View check run for this annotation

Codecov / codecov/patch

onetl/connection/db_connection/dialect_mixins/support_hwm_column_str.py#L14

Added line #L14 was not covered by tests
f"{connection.__class__.__name__} requires 'hwm_column' parameter type to be 'str', "
f"got {type(hwm_column)}",
)

return hwm_column
2 changes: 2 additions & 0 deletions onetl/connection/db_connection/greenplum.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
SupportColumnsList,
SupportDfSchemaNone,
SupportHintNone,
SupportHWMColumnStr,
SupportHWMExpressionStr,
SupportWhereStr,
)
Expand Down Expand Up @@ -447,6 +448,7 @@ class Dialect( # noqa: WPS215
SupportWhereStr,
SupportHintNone,
SupportHWMExpressionStr,
SupportHWMColumnStr,
DBConnection.Dialect,
):
@classmethod
Expand Down
2 changes: 2 additions & 0 deletions onetl/connection/db_connection/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
SupportColumnsList,
SupportDfSchemaNone,
SupportHintStr,
SupportHWMColumnStr,
SupportHWMExpressionStr,
SupportWhereStr,
)
Expand Down Expand Up @@ -504,6 +505,7 @@ class Dialect( # noqa: WPS215
SupportWhereStr,
SupportHintStr,
SupportHWMExpressionStr,
SupportHWMColumnStr,
DBConnection.Dialect,
):
pass
Expand Down
2 changes: 2 additions & 0 deletions onetl/connection/db_connection/jdbc_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
SupportColumnsList,
SupportDfSchemaNone,
SupportHintStr,
SupportHWMColumnStr,
SupportHWMExpressionStr,
SupportWhereStr,
)
Expand Down Expand Up @@ -148,6 +149,7 @@ class Dialect( # noqa: WPS215
SupportWhereStr,
SupportHintStr,
SupportHWMExpressionStr,
SupportHWMColumnStr,
DBConnection.Dialect,
):
pass
Expand Down
38 changes: 36 additions & 2 deletions onetl/connection/db_connection/kafka/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

import logging

from onetl.connection.db_connection.db_connection import DBConnection
from onetl.connection.db_connection.db_connection import BaseDBConnection, DBConnection
from onetl.connection.db_connection.dialect_mixins import (
SupportColumnsNone,
SupportDfSchemaNone,
SupportHintNone,
SupportHWMExpressionNone,
SupportTableWithoutDBSchema,
SupportWhereNone,
)
Expand All @@ -35,6 +36,39 @@
SupportHintNone,
SupportWhereNone,
SupportTableWithoutDBSchema,
SupportHWMExpressionNone,
DBConnection.Dialect,
):
pass
valid_hwm_columns = {"offset", "timestamp"}

@classmethod
def validate_hwm_column(
cls,
connection: BaseDBConnection,
hwm_column: str | None,
) -> str | None:
if not isinstance(hwm_column, str):
raise ValueError(

Check warning on line 51 in onetl/connection/db_connection/kafka/dialect.py

View check run for this annotation

Codecov / codecov/patch

onetl/connection/db_connection/kafka/dialect.py#L51

Added line #L51 was not covered by tests
f"{connection.__class__.__name__} requires 'hwm_column' parameter type to be 'str', "
f"got {type(hwm_column)}",
)
dolfinus marked this conversation as resolved.
Show resolved Hide resolved

cls.validate_column(connection, hwm_column)

return hwm_column

@classmethod
def validate_column(cls, connection: BaseDBConnection, column: str) -> None:
if column not in cls.valid_hwm_columns:
raise ValueError(f"{column} is not a valid hwm column. Valid options are: {cls.valid_hwm_columns}")
if column == "timestamp":
maxim-lixakov marked this conversation as resolved.
Show resolved Hide resolved
# Spark version less 3.x does not support reading from Kafka with the timestamp parameter
cls._check_spark_version(connection)

@staticmethod
def _check_spark_version(connection: BaseDBConnection) -> None:
spark_version = connection.spark.version # type: ignore[attr-defined]
major_version = int(spark_version.split(".")[0])

if major_version < 3:
raise ValueError(f"Spark version must be 3.x for the timestamp column. Current version is: {spark_version}")
4 changes: 3 additions & 1 deletion onetl/connection/db_connection/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from onetl.connection.db_connection.dialect_mixins import (
SupportColumnsNone,
SupportDfSchemaStruct,
SupportHWMColumnStr,
SupportHWMExpressionNone,
)
from onetl.connection.db_connection.dialect_mixins.support_table_without_dbschema import (
Expand Down Expand Up @@ -401,11 +402,12 @@ class Config:
known_options = KNOWN_WRITE_OPTIONS
extra = "allow"

class Dialect(
class Dialect( # noqa: WPS215
SupportTableWithoutDBSchema,
SupportHWMExpressionNone,
SupportColumnsNone,
SupportDfSchemaStruct,
SupportHWMColumnStr,
DBConnection.Dialect,
):
_compare_statements: ClassVar[Dict[Callable, str]] = {
Expand Down
2 changes: 2 additions & 0 deletions onetl/connection/db_connection/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
SupportColumnsList,
SupportDfSchemaNone,
SupportHintNone,
SupportHWMColumnStr,
SupportHWMExpressionStr,
SupportWhereStr,
)
Expand Down Expand Up @@ -134,6 +135,7 @@ class Dialect( # noqa: WPS215
SupportDfSchemaNone,
SupportWhereStr,
SupportHWMExpressionStr,
SupportHWMColumnStr,
SupportHintNone,
DBConnection.Dialect,
):
Expand Down
5 changes: 4 additions & 1 deletion onetl/db/db_reader/db_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,9 @@ def validate_hwm_column(cls, values: dict) -> dict:
hwm_column: str | tuple[str, str] | Column | None = values.get("hwm_column")
df_schema: StructType | None = values.get("df_schema")
hwm_expression: str | None = values.get("hwm_expression")
connection: BaseDBConnection = values["connection"]

if hwm_column is None or isinstance(hwm_column, Column):
# nothing to validate
return values

if not hwm_expression and not isinstance(hwm_column, str):
Expand All @@ -398,6 +398,9 @@ def validate_hwm_column(cls, values: dict) -> dict:
values["hwm_column"] = Column(name=hwm_column) # type: ignore
values["hwm_expression"] = hwm_expression

dialect = connection.Dialect
dialect.validate_hwm_column(connection, hwm_column) # type: ignore
Comment on lines 398 to +402
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
values["hwm_column"] = Column(name=hwm_column) # type: ignore
values["hwm_expression"] = hwm_expression
dialect = connection.Dialect
dialect.validate_hwm_column(connection, hwm_column) # type: ignore
dialect = connection.Dialect
dialect.validate_hwm_column(connection, hwm_column) # type: ignore
values["hwm_column"] = Column(name=hwm_column) # type: ignore
values["hwm_expression"] = hwm_expression

я бы сделал так, а то странно проводить валидацию после присвоения в values


return values

@root_validator(pre=True) # noqa: WPS231
Expand Down
4 changes: 1 addition & 3 deletions tests/fixtures/spark_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

@pytest.fixture(
scope="function",
params=[
pytest.param("mock", marks=[pytest.mark.db_connection, pytest.mark.connection]),
],
params=[pytest.param("mock", marks=[pytest.mark.db_connection, pytest.mark.connection])],
)
def spark_mock():
from pyspark.sql import SparkSession
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from unittest.mock import patch

import pytest
from etl_entities import Column

from onetl.connection import Kafka
from onetl.db import DBReader
Expand Down Expand Up @@ -86,3 +89,63 @@ def test_kafka_reader_unsupported_parameters(spark_mock, df_schema):
table="table",
df_schema=df_schema,
)


def test_kafka_reader_valid_hwm_column(spark_mock):
kafka = Kafka(
addresses=["localhost:9092"],
cluster="my_cluster",
spark=spark_mock,
)

DBReader(
connection=kafka,
table="table",
hwm_column="offset",
dolfinus marked this conversation as resolved.
Show resolved Hide resolved
)

DBReader(
connection=kafka,
table="table",
hwm_column=Column(name="offset"),
)


def test_kafka_reader_hwm_column_by_version(spark_mock):
kafka = Kafka(
addresses=["localhost:9092"],
cluster="my_cluster",
spark=spark_mock,
)
with patch.object(spark_mock, "version", new="3.3.0"):
DBReader(
connection=kafka,
table="table",
hwm_column="timestamp",
)
with patch.object(spark_mock, "version", new="2.3.0"):
with pytest.raises(ValueError, match="Spark version must be 3.x"):
DBReader(
connection=kafka,
table="table",
hwm_column="timestamp",
)


@pytest.mark.parametrize("hwm_column", ["unknown", '("some", "thing")'])
def test_kafka_reader_invalid_hwm_column(spark_mock, hwm_column):
kafka = Kafka(
addresses=["localhost:9092"],
cluster="my_cluster",
spark=spark_mock,
)

with pytest.raises(
ValueError,
match="is not a valid hwm column",
):
DBReader(
connection=kafka,
table="table",
hwm_column=hwm_column,
)
Loading