diff --git a/.github/workflows/data/core/matrix.yml b/.github/workflows/data/core/matrix.yml
index d20f074ab..8362ef500 100644
--- a/.github/workflows/data/core/matrix.yml
+++ b/.github/workflows/data/core/matrix.yml
@@ -8,7 +8,7 @@ min: &min
 max: &max
   spark-version: 3.5.1
   pydantic-version: 2
-  python-version: '3.12'
+  python-version: '3.13.0-beta.4'
   java-version: 20
   os: ubuntu-latest
 
diff --git a/.github/workflows/data/ftp/matrix.yml b/.github/workflows/data/ftp/matrix.yml
index d01c39029..878f9b779 100644
--- a/.github/workflows/data/ftp/matrix.yml
+++ b/.github/workflows/data/ftp/matrix.yml
@@ -5,7 +5,7 @@ min: &min
 
 max: &max
   pydantic-version: 2
-  python-version: '3.12'
+  python-version: '3.13.0-beta.4'
   os: ubuntu-latest
 
 latest: &latest
diff --git a/.github/workflows/data/ftps/matrix.yml b/.github/workflows/data/ftps/matrix.yml
index efe28e79a..40ec8fc9a 100644
--- a/.github/workflows/data/ftps/matrix.yml
+++ b/.github/workflows/data/ftps/matrix.yml
@@ -5,7 +5,7 @@ min: &min
 
 max: &max
   pydantic-version: 2
-  python-version: '3.12'
+  python-version: '3.13.0-beta.4'
   os: ubuntu-latest
 
 latest: &latest
diff --git a/.github/workflows/data/hdfs/matrix.yml b/.github/workflows/data/hdfs/matrix.yml
index 6d8156c50..45cbc1d96 100644
--- a/.github/workflows/data/hdfs/matrix.yml
+++ b/.github/workflows/data/hdfs/matrix.yml
@@ -10,7 +10,7 @@ max: &max
   hadoop-version: hadoop3-hdfs
   spark-version: 3.5.1
   pydantic-version: 2
-  python-version: '3.12'
+  python-version: '3.13.0-beta.4'
   java-version: 20
   os: ubuntu-latest
 
diff --git a/.github/workflows/data/s3/matrix.yml b/.github/workflows/data/s3/matrix.yml
index d9b9338f8..df7f08b38 100644
--- a/.github/workflows/data/s3/matrix.yml
+++ b/.github/workflows/data/s3/matrix.yml
@@ -12,7 +12,7 @@ max: &max
   minio-version: 2024.4.18
   spark-version: 3.5.1
   pydantic-version: 2
-  python-version: '3.12'
+  python-version: '3.13.0-beta.4'
   java-version: 20
   os: ubuntu-latest
 
diff --git a/.github/workflows/data/sftp/matrix.yml b/.github/workflows/data/sftp/matrix.yml
index a32f6f823..a57f6dfe3 100644
--- a/.github/workflows/data/sftp/matrix.yml
+++ b/.github/workflows/data/sftp/matrix.yml
@@ -5,7 +5,7 @@ min: &min
 
 max: &max
   pydantic-version: 2
-  python-version: '3.12'
+  python-version: '3.13.0-beta.4'
   os: ubuntu-latest
 
 latest: &latest
diff --git a/.github/workflows/data/webdav/matrix.yml b/.github/workflows/data/webdav/matrix.yml
index fb76e3282..57423e616 100644
--- a/.github/workflows/data/webdav/matrix.yml
+++ b/.github/workflows/data/webdav/matrix.yml
@@ -5,7 +5,7 @@ min: &min
 
 max: &max
   pydantic-version: 2
-  python-version: '3.12'
+  python-version: '3.13.0-beta.4'
   os: ubuntu-latest
 
 latest: &latest
diff --git a/.github/workflows/test-sftp.yml b/.github/workflows/test-sftp.yml
index eaa5e5a43..bff893355 100644
--- a/.github/workflows/test-sftp.yml
+++ b/.github/workflows/test-sftp.yml
@@ -58,7 +58,7 @@ jobs:
 
       - name: Install dependencies
        run: |
-          pip install -I -r requirements/core.txt -r requirements/sftp.txt -r requirements/tests/base.txt -r requirements/tests/pydantic-${{ inputs.pydantic-version }}.txt
+          pip install -I -r requirements/core.txt -r requirements/sftp.txt -r requirements/tests/base.txt -r requirements/tests/pydantic-${{ inputs.pydantic-version }}.txt cffi==1.17.0rc1
 
      - name: Run tests
        run: |
diff --git a/.github/workflows/test-webdav.yml b/.github/workflows/test-webdav.yml
index 34a943260..cfff17712 100644
--- a/.github/workflows/test-webdav.yml
+++ b/.github/workflows/test-webdav.yml
@@ -33,6 +33,12 @@ jobs:
        with:
          python-version: ${{ inputs.python-version }}
 
+      - name: Set up lxml libs
+        if: runner.os == 'Linux'
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y --no-install-recommends libxml2-dev libxslt-dev
+
      - name: Cache pip
        uses: actions/cache@v4
        if: inputs.with-cache
diff --git a/onetl/base/base_file_format.py b/onetl/base/base_file_format.py
index a4c72e3e5..9e63fda57 100644
--- a/onetl/base/base_file_format.py
+++ b/onetl/base/base_file_format.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, ContextManager
+from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from pyspark.sql import DataFrameReader, DataFrameWriter, SparkSession
@@ -30,7 +30,7 @@ def check_if_supported(self, spark: SparkSession) -> None:
         """
 
     @abstractmethod
-    def apply_to_reader(self, reader: DataFrameReader) -> DataFrameReader | ContextManager[DataFrameReader]:
+    def apply_to_reader(self, reader: DataFrameReader) -> DataFrameReader:
         """
         Apply provided format to :obj:`pyspark.sql.DataFrameReader`. |support_hooks|
 
@@ -40,10 +40,6 @@ def apply_to_reader(self, reader: DataFrameReader) -> DataFrameReader | ContextM
         -------
         :obj:`pyspark.sql.DataFrameReader`
             DataFrameReader with options applied.
-
-        ``ContextManager[DataFrameReader]``
-            If returned context manager, it will be entered before reading data and exited after creating a DataFrame.
-            Context manager's ``__enter__`` method should return :obj:`pyspark.sql.DataFrameReader` instance.
         """
 
 
@@ -68,7 +64,7 @@ def check_if_supported(self, spark: SparkSession) -> None:
         """
 
     @abstractmethod
-    def apply_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter | ContextManager[DataFrameWriter]:
+    def apply_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter:
         """
         Apply provided format to :obj:`pyspark.sql.DataFrameWriter`. |support_hooks|
 
@@ -78,8 +74,4 @@ def apply_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter | ContextM
         -------
         :obj:`pyspark.sql.DataFrameWriter`
             DataFrameWriter with options applied.
-
-        ``ContextManager[DataFrameWriter]``
-            If returned context manager, it will be entered before writing and exited after writing a DataFrame.
-            Context manager's ``__enter__`` method should return :obj:`pyspark.sql.DataFrameWriter` instance.
""" diff --git a/onetl/connection/db_connection/greenplum/connection.py b/onetl/connection/db_connection/greenplum/connection.py index 7ed60539b..4e26c86ef 100644 --- a/onetl/connection/db_connection/greenplum/connection.py +++ b/onetl/connection/db_connection/greenplum/connection.py @@ -5,17 +5,18 @@ import logging import os import textwrap +import threading import warnings -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, Optional from etl_entities.instance import Host from onetl.connection.db_connection.jdbc_connection.options import JDBCReadOptions try: - from pydantic.v1 import validator + from pydantic.v1 import PrivateAttr, SecretStr, validator except (ImportError, AttributeError): - from pydantic import validator # type: ignore[no-redef, assignment] + from pydantic import validator, SecretStr, PrivateAttr # type: ignore[no-redef, assignment] from onetl._util.classproperty import classproperty from onetl._util.java import try_import_java_class @@ -39,7 +40,9 @@ from onetl.connection.db_connection.jdbc_mixin.options import ( JDBCExecuteOptions, JDBCFetchOptions, - JDBCOptions, +) +from onetl.connection.db_connection.jdbc_mixin.options import ( + JDBCOptions as JDBCMixinOptions, ) from onetl.exception import MISSING_JVM_CLASS_MSG, TooManyParallelJobsError from onetl.hooks import slot, support_hooks @@ -69,11 +72,11 @@ class GreenplumExtra(GenericOptions): class Config: extra = "allow" - prohibited_options = JDBCOptions.Config.prohibited_options + prohibited_options = JDBCMixinOptions.Config.prohibited_options @support_hooks -class Greenplum(JDBCMixin, DBConnection): +class Greenplum(JDBCMixin, DBConnection): # noqa: WPS338 """Greenplum connection. |support_hooks| Based on package ``io.pivotal:greenplum-spark:2.2.0`` @@ -157,6 +160,8 @@ class Greenplum(JDBCMixin, DBConnection): """ host: Host + user: str + password: SecretStr database: str port: int = 5432 extra: GreenplumExtra = GreenplumExtra() @@ -166,6 +171,7 @@ class Greenplum(JDBCMixin, DBConnection): SQLOptions = GreenplumSQLOptions FetchOptions = GreenplumFetchOptions ExecuteOptions = GreenplumExecuteOptions + JDBCOptions = JDBCMixinOptions Extra = GreenplumExtra Dialect = GreenplumDialect @@ -174,6 +180,9 @@ class Greenplum(JDBCMixin, DBConnection): CONNECTIONS_WARNING_LIMIT: ClassVar[int] = 31 CONNECTIONS_EXCEPTION_LIMIT: ClassVar[int] = 100 + _CHECK_QUERY: ClassVar[str] = "SELECT 1" + _last_connection_and_options: Optional[threading.local] = PrivateAttr(default=None) + @slot @classmethod def get_packages( diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py index 81c50e87e..696466869 100644 --- a/onetl/connection/db_connection/hive/connection.py +++ b/onetl/connection/db_connection/hive/connection.py @@ -15,6 +15,7 @@ from onetl._util.spark import inject_spark_param from onetl._util.sql import clear_statement +from onetl.base import BaseWritableFileFormat from onetl.connection.db_connection.db_connection import DBConnection from onetl.connection.db_connection.hive.dialect import HiveDialect from onetl.connection.db_connection.hive.options import ( @@ -23,7 +24,7 @@ HiveWriteOptions, ) from onetl.connection.db_connection.hive.slots import HiveSlots -from onetl.file.format.file_format import WriteOnlyFileFormat +from onetl.file.format.file_format import ReadWriteFileFormat, WriteOnlyFileFormat from onetl.hooks import slot, support_hooks from onetl.hwm import Window from onetl.log import log_lines, log_with_indent @@ 
-456,7 +457,7 @@ def _format_write_options(self, write_options: HiveWriteOptions) -> dict: exclude={"if_exists"}, ) - if isinstance(write_options.format, WriteOnlyFileFormat): + if isinstance(write_options.format, (WriteOnlyFileFormat, ReadWriteFileFormat)): options_dict["format"] = write_options.format.name options_dict.update(write_options.format.dict(exclude={"name"})) @@ -485,7 +486,7 @@ def _save_as_table( writer = writer.option(method, value) # deserialize passed OCR(), Parquet(), CSV(), etc. file formats - if isinstance(write_options.format, WriteOnlyFileFormat): + if isinstance(write_options.format, BaseWritableFileFormat): writer = write_options.format.apply_to_writer(writer) elif isinstance(write_options.format, str): writer = writer.format(write_options.format) diff --git a/onetl/connection/db_connection/hive/options.py b/onetl/connection/db_connection/hive/options.py index 16d21a0e7..9eb558bad 100644 --- a/onetl/connection/db_connection/hive/options.py +++ b/onetl/connection/db_connection/hive/options.py @@ -13,7 +13,7 @@ from typing_extensions import deprecated -from onetl.file.format.file_format import WriteOnlyFileFormat +from onetl.base import BaseWritableFileFormat from onetl.impl import GenericOptions @@ -199,7 +199,7 @@ class Config: does not affect behavior. """ - format: Union[str, WriteOnlyFileFormat] = "orc" + format: Union[str, BaseWritableFileFormat] = "orc" """Format of files which should be used for storing table data. Examples diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py index 5b0aebeb8..8ad2be898 100644 --- a/onetl/connection/db_connection/jdbc_connection/connection.py +++ b/onetl/connection/db_connection/jdbc_connection/connection.py @@ -4,9 +4,16 @@ import logging import secrets +import threading import warnings -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar, Optional +try: + from pydantic.v1 import PrivateAttr, SecretStr, validator +except (ImportError, AttributeError): + from pydantic import PrivateAttr, SecretStr, validator # type: ignore[no-redef, assignment] + +from onetl._util.java import try_import_java_class from onetl._util.sql import clear_statement from onetl.connection.db_connection.db_connection import DBConnection from onetl.connection.db_connection.jdbc_connection.dialect import JDBCDialect @@ -19,7 +26,14 @@ JDBCWriteOptions, ) from onetl.connection.db_connection.jdbc_mixin import JDBCMixin -from onetl.connection.db_connection.jdbc_mixin.options import JDBCFetchOptions +from onetl.connection.db_connection.jdbc_mixin.options import ( + JDBCExecuteOptions, + JDBCFetchOptions, +) +from onetl.connection.db_connection.jdbc_mixin.options import ( + JDBCOptions as JDBCMixinOptions, +) +from onetl.exception import MISSING_JVM_CLASS_MSG from onetl.hooks import slot, support_hooks from onetl.hwm import Window from onetl.log import log_lines, log_with_indent @@ -44,13 +58,38 @@ @support_hooks -class JDBCConnection(JDBCMixin, DBConnection): +class JDBCConnection(JDBCMixin, DBConnection): # noqa: WPS338 + user: str + password: SecretStr + + DRIVER: ClassVar[str] + _CHECK_QUERY: ClassVar[str] = "SELECT 1" + _last_connection_and_options: Optional[threading.local] = PrivateAttr(default=None) + + JDBCOptions = JDBCMixinOptions + FetchOptions = JDBCFetchOptions + ExecuteOptions = JDBCExecuteOptions Dialect = JDBCDialect ReadOptions = JDBCReadOptions SQLOptions = JDBCSQLOptions WriteOptions = JDBCWriteOptions Options = 
JDBCLegacyOptions + @validator("spark") + def _check_java_class_imported(cls, spark): + try: + try_import_java_class(spark, cls.DRIVER) + except Exception as e: + msg = MISSING_JVM_CLASS_MSG.format( + java_class=cls.DRIVER, + package_source=cls.__name__, + args="", + ) + if log.isEnabledFor(logging.DEBUG): + log.debug("Missing Java class", exc_info=e, stack_info=True) + raise ValueError(msg) from e + return spark + @slot def sql( self, @@ -109,11 +148,16 @@ def read_source_as_df( limit: int | None = None, options: JDBCReadOptions | None = None, ) -> DataFrame: + if isinstance(options, JDBCLegacyOptions): + raw_options = self.ReadOptions.parse(options.dict(exclude_unset=True)) + else: + raw_options = self.ReadOptions.parse(options) + read_options = self._set_lower_upper_bound( table=source, where=where, hint=hint, - options=self.ReadOptions.parse(options), + options=raw_options, ) new_columns = columns or ["*"] @@ -170,7 +214,11 @@ def write_df_to_target( target: str, options: JDBCWriteOptions | None = None, ) -> None: - write_options = self.WriteOptions.parse(options) + if isinstance(options, JDBCLegacyOptions): + write_options = self.WriteOptions.parse(options.dict(exclude_unset=True)) + else: + write_options = self.WriteOptions.parse(options) + jdbc_properties = self._get_jdbc_properties(write_options, exclude={"if_exists"}, exclude_none=True) mode = ( diff --git a/onetl/connection/db_connection/jdbc_connection/options.py b/onetl/connection/db_connection/jdbc_connection/options.py index a2aa39adb..7e189e86f 100644 --- a/onetl/connection/db_connection/jdbc_connection/options.py +++ b/onetl/connection/db_connection/jdbc_connection/options.py @@ -672,7 +672,19 @@ def _check_partition_fields(cls, values): "Deprecated in 0.5.0 and will be removed in 1.0.0. 
Use 'ReadOptions' or 'WriteOptions' instead", category=UserWarning, ) -class JDBCLegacyOptions(JDBCReadOptions, JDBCWriteOptions): +class JDBCLegacyOptions(GenericOptions): class Config: prohibited_options = GENERIC_PROHIBITED_OPTIONS + known_options = READ_OPTIONS | WRITE_OPTIONS | READ_WRITE_OPTIONS extra = "allow" + + partition_column: Optional[str] = Field(default=None, alias="partitionColumn") + num_partitions: PositiveInt = Field(default=1, alias="numPartitions") + lower_bound: Optional[int] = Field(default=None, alias="lowerBound") + upper_bound: Optional[int] = Field(default=None, alias="upperBound") + session_init_statement: Optional[str] = Field(default=None, alias="sessionInitStatement") + query_timeout: Optional[int] = Field(default=None, alias="queryTimeout") + if_exists: JDBCTableExistBehavior = Field(default=JDBCTableExistBehavior.APPEND, alias="mode") + isolation_level: str = Field(default="READ_UNCOMMITTED", alias="isolationLevel") + fetchsize: int = 100_000 + batchsize: int = 20_000 diff --git a/onetl/connection/db_connection/jdbc_mixin/connection.py b/onetl/connection/db_connection/jdbc_mixin/connection.py index e8c19e38b..4b9c3b2ce 100644 --- a/onetl/connection/db_connection/jdbc_mixin/connection.py +++ b/onetl/connection/db_connection/jdbc_mixin/connection.py @@ -12,11 +12,11 @@ from onetl.impl.generic_options import GenericOptions try: - from pydantic.v1 import Field, PrivateAttr, SecretStr, validator + from pydantic.v1 import Field, PrivateAttr, SecretStr except (ImportError, AttributeError): - from pydantic import Field, PrivateAttr, SecretStr, validator # type: ignore[no-redef, assignment] + from pydantic import Field, PrivateAttr, SecretStr # type: ignore[no-redef, assignment] -from onetl._util.java import get_java_gateway, try_import_java_class +from onetl._util.java import get_java_gateway from onetl._util.spark import get_spark_version, stringify from onetl._util.sql import clear_statement from onetl._util.version import Version @@ -27,9 +27,7 @@ from onetl.connection.db_connection.jdbc_mixin.options import ( JDBCOptions as JDBCMixinOptions, ) -from onetl.exception import MISSING_JVM_CLASS_MSG from onetl.hooks import slot, support_hooks -from onetl.impl import FrozenModel from onetl.log import log_lines if TYPE_CHECKING: @@ -57,7 +55,7 @@ class JDBCStatementType(Enum): @support_hooks -class JDBCMixin(FrozenModel): +class JDBCMixin: """ Compatibility layer between Python and Java SQL Module. 
@@ -286,21 +284,6 @@ def execute(
         log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__)
         return df
 
-    @validator("spark")
-    def _check_java_class_imported(cls, spark):
-        try:
-            try_import_java_class(spark, cls.DRIVER)
-        except Exception as e:
-            msg = MISSING_JVM_CLASS_MSG.format(
-                java_class=cls.DRIVER,
-                package_source=cls.__name__,
-                args="",
-            )
-            if log.isEnabledFor(logging.DEBUG):
-                log.debug("Missing Java class", exc_info=e, stack_info=True)
-            raise ValueError(msg) from e
-        return spark
-
     def _query_on_driver(
         self,
         query: str,
diff --git a/onetl/connection/file_df_connection/spark_file_df_connection.py b/onetl/connection/file_df_connection/spark_file_df_connection.py
index 06121139f..892cd8efd 100644
--- a/onetl/connection/file_df_connection/spark_file_df_connection.py
+++ b/onetl/connection/file_df_connection/spark_file_df_connection.py
@@ -76,11 +76,7 @@ def read_files_as_df(
         reader: DataFrameReader = self.spark.read
 
         with ExitStack() as stack:
-            format_result = format.apply_to_reader(reader)
-            if isinstance(format_result, AbstractContextManager):
-                reader = stack.enter_context(format_result)
-            else:
-                reader = format_result
+            reader = format.apply_to_reader(reader)
 
             if root:
                 reader = reader.option("basePath", self._convert_to_url(root))
@@ -111,12 +107,7 @@ def write_df_as_files(
         writer: DataFrameWriter = df.write
 
         with ExitStack() as stack:
-            format_result = format.apply_to_writer(writer)
-
-            if isinstance(format_result, AbstractContextManager):
-                writer = stack.enter_context(format_result)
-            else:
-                writer = format_result
+            writer = format.apply_to_writer(writer)
 
             if options:
                 options_result = options.apply_to_writer(writer)
diff --git a/onetl/file/format/file_format.py b/onetl/file/format/file_format.py
index 4a7de4f3e..3b8a8fc5f 100644
--- a/onetl/file/format/file_format.py
+++ b/onetl/file/format/file_format.py
@@ -51,5 +51,19 @@ def apply_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter:
         return writer.format(self.name).options(**options)
 
 
-class ReadWriteFileFormat(ReadOnlyFileFormat, WriteOnlyFileFormat):
-    pass
+@support_hooks
+class ReadWriteFileFormat(BaseReadableFileFormat, BaseWritableFileFormat, GenericOptions):
+    name: ClassVar[str]
+
+    class Config:
+        prohibited_options = PROHIBITED_OPTIONS
+
+    @slot
+    def apply_to_reader(self, reader: DataFrameReader) -> DataFrameReader:
+        options = self.dict(by_alias=True)
+        return reader.format(self.name).options(**options)
+
+    @slot
+    def apply_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter:
+        options = self.dict(by_alias=True)
+        return writer.format(self.name).options(**options)
diff --git a/onetl/strategy/incremental_strategy.py b/onetl/strategy/incremental_strategy.py
index 0397514b6..613d22c88 100644
--- a/onetl/strategy/incremental_strategy.py
+++ b/onetl/strategy/incremental_strategy.py
@@ -6,23 +6,11 @@
 
 from etl_entities.hwm import HWM
 
-from onetl.impl import BaseModel
 from onetl.strategy.batch_hwm_strategy import BatchHWMStrategy
 from onetl.strategy.hwm_strategy import HWMStrategy
 
 
-class OffsetMixin(BaseModel):
-    hwm: Optional[HWM] = None
-    offset: Any = None
-
-    def fetch_hwm(self) -> None:
-        super().fetch_hwm()
-
-        if self.hwm and self.hwm.value is not None and self.offset is not None:
-            self.hwm -= self.offset
-
-
-class IncrementalStrategy(OffsetMixin, HWMStrategy):
+class IncrementalStrategy(HWMStrategy):
     """Incremental strategy for :ref:`db-reader`/:ref:`file-downloader`.
 
     Used for fetching only new rows/files from a source
@@ -353,8 +341,17 @@ class IncrementalStrategy(OffsetMixin, HWMStrategy):
         # current run will download only files which were not downloaded in previous runs
     """
 
+    hwm: Optional[HWM] = None
+    offset: Any = None
+
+    def fetch_hwm(self) -> None:
+        super().fetch_hwm()
+
+        if self.hwm and self.hwm.value is not None and self.offset is not None:
+            self.hwm -= self.offset
 
-class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy):
+
+class IncrementalBatchStrategy(BatchHWMStrategy):
     """Incremental batch strategy for :ref:`db-reader`.
 
     .. note::
@@ -669,6 +666,15 @@ class IncrementalBatchStrategy(OffsetMixin, BatchHWMStrategy):
 
     """
 
+    hwm: Optional[HWM] = None
+    offset: Any = None
+
+    def fetch_hwm(self) -> None:
+        super().fetch_hwm()
+
+        if self.hwm and self.hwm.value is not None and self.offset is not None:
+            self.hwm -= self.offset
+
     def __next__(self):
         self.save_hwm()
         return super().__next__()