diff --git a/.github/workflows/data/core/matrix.yml b/.github/workflows/data/core/matrix.yml index d20f074ab..6ce0d7a8e 100644 --- a/.github/workflows/data/core/matrix.yml +++ b/.github/workflows/data/core/matrix.yml @@ -22,4 +22,4 @@ latest: &latest matrix: small: [*max] full: [*min, *max] - nightly: [*min, *max, *latest] + nightly: [*min, *latest] diff --git a/onetl/_util/scala.py b/onetl/_util/scala.py index 397a91576..ba80c07f1 100644 --- a/onetl/_util/scala.py +++ b/onetl/_util/scala.py @@ -11,4 +11,6 @@ def get_default_scala_version(spark_version: Version) -> Version: """ if spark_version.major < 3: return Version("2.11") - return Version("2.12") + if spark_version.major < 4: + return Version("2.12") + return Version("2.13") diff --git a/onetl/connection/db_connection/jdbc_mixin/connection.py b/onetl/connection/db_connection/jdbc_mixin/connection.py index e8c19e38b..0ff6b5245 100644 --- a/onetl/connection/db_connection/jdbc_mixin/connection.py +++ b/onetl/connection/db_connection/jdbc_mixin/connection.py @@ -431,10 +431,11 @@ def _execute_on_driver( statement_args = self._get_statement_args() jdbc_statement = self._build_statement(statement, statement_type, jdbc_connection, statement_args) - return self._execute_statement(jdbc_statement, statement, options, callback, read_only) + return self._execute_statement(jdbc_connection, jdbc_statement, statement, options, callback, read_only) def _execute_statement( self, + jdbc_connection, jdbc_statement, statement: str, options: JDBCFetchOptions | JDBCExecuteOptions, @@ -472,7 +473,7 @@ def _execute_statement( else: jdbc_statement.executeUpdate(statement) - return callback(jdbc_statement) + return callback(jdbc_connection, jdbc_statement) @staticmethod def _build_statement( @@ -501,11 +502,11 @@ def _build_statement( return jdbc_connection.createStatement(*statement_args) - def _statement_to_dataframe(self, jdbc_statement) -> DataFrame: + def _statement_to_dataframe(self, jdbc_connection, jdbc_statement) -> DataFrame: result_set = jdbc_statement.getResultSet() - return self._resultset_to_dataframe(result_set) + return self._resultset_to_dataframe(jdbc_connection, result_set) - def _statement_to_optional_dataframe(self, jdbc_statement) -> DataFrame | None: + def _statement_to_optional_dataframe(self, jdbc_connection, jdbc_statement) -> DataFrame | None: """ Returns ``org.apache.spark.sql.DataFrame`` or ``None``, if ResultSet is does not contain any columns. @@ -522,9 +523,9 @@ def _statement_to_optional_dataframe(self, jdbc_statement) -> DataFrame | None: if not result_column_count: return None - return self._resultset_to_dataframe(result_set) + return self._resultset_to_dataframe(jdbc_connection, result_set) - def _resultset_to_dataframe(self, result_set) -> DataFrame: + def _resultset_to_dataframe(self, jdbc_connection, result_set) -> DataFrame: """ Converts ``java.sql.ResultSet`` to ``org.apache.spark.sql.DataFrame`` using Spark's internal methods. @@ -545,13 +546,25 @@ def _resultset_to_dataframe(self, result_set) -> DataFrame: java_converters = self.spark._jvm.scala.collection.JavaConverters # type: ignore - if get_spark_version(self.spark) >= Version("3.4"): + if get_spark_version(self.spark) >= Version("4.0"): + result_schema = jdbc_utils.getSchema( + jdbc_connection, + result_set, + jdbc_dialect, + False, # noqa: WPS425 + False, # noqa: WPS425 + ) + elif get_spark_version(self.spark) >= Version("3.4"): # https://github.com/apache/spark/commit/2349175e1b81b0a61e1ed90c2d051c01cf78de9b result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False, False) # noqa: WPS425 else: result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False) # noqa: WPS425 - result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema) + if get_spark_version(self.spark) >= Version("4.0"): + result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema, jdbc_dialect) + else: + result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema) + result_list = java_converters.seqAsJavaListConverter(result_iterator.toSeq()).asJava() jdf = self.spark._jsparkSession.createDataFrame(result_list, result_schema) # type: ignore diff --git a/onetl/connection/db_connection/kafka/connection.py b/onetl/connection/db_connection/kafka/connection.py index ce3829e49..6d8b377c5 100644 --- a/onetl/connection/db_connection/kafka/connection.py +++ b/onetl/connection/db_connection/kafka/connection.py @@ -432,8 +432,13 @@ def get_packages( raise ValueError(f"Spark version must be at least 2.4, got {spark_ver}") scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver) + + if spark_ver.major < 4: + version = spark_ver.format("{0}.{1}.{2}") + else: + version = "4.0.0-preview1" return [ - f"org.apache.spark:spark-sql-kafka-0-10_{scala_ver.format('{0}.{1}')}:{spark_ver.format('{0}.{1}.{2}')}", + f"org.apache.spark:spark-sql-kafka-0-10_{scala_ver.format('{0}.{1}')}:{version}", ] def __enter__(self): diff --git a/onetl/connection/file_df_connection/spark_s3/connection.py b/onetl/connection/file_df_connection/spark_s3/connection.py index 1efe39d4e..8b9e76c36 100644 --- a/onetl/connection/file_df_connection/spark_s3/connection.py +++ b/onetl/connection/file_df_connection/spark_s3/connection.py @@ -246,9 +246,14 @@ def get_packages( # https://issues.apache.org/jira/browse/SPARK-23977 raise ValueError(f"Spark version must be at least 3.x, got {spark_ver}") + if spark_ver.major < 4: + version = spark_ver.format("{0}.{1}.{2}") + else: + version = "4.0.0-preview1" + scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver) # https://mvnrepository.com/artifact/org.apache.spark/spark-hadoop-cloud - return [f"org.apache.spark:spark-hadoop-cloud_{scala_ver.format('{0}.{1}')}:{spark_ver.format('{0}.{1}.{2}')}"] + return [f"org.apache.spark:spark-hadoop-cloud_{scala_ver.format('{0}.{1}')}:{version}"] @slot def path_from_string(self, path: os.PathLike | str) -> RemotePath: diff --git a/onetl/file/format/avro.py b/onetl/file/format/avro.py index 3699620b0..0205c2451 100644 --- a/onetl/file/format/avro.py +++ b/onetl/file/format/avro.py @@ -163,7 +163,12 @@ def get_packages( if scala_ver < Version("2.11"): raise ValueError(f"Scala version should be at least 2.11, got {scala_ver.format('{0}.{1}')}") - return [f"org.apache.spark:spark-avro_{scala_ver.format('{0}.{1}')}:{spark_ver.format('{0}.{1}.{2}')}"] + if spark_ver.major < 4: + version = spark_ver.format("{0}.{1}.{2}") + else: + version = "4.0.0-preview1" + + return [f"org.apache.spark:spark-avro_{scala_ver.format('{0}.{1}')}:{version}"] @slot def check_if_supported(self, spark: SparkSession) -> None: diff --git a/onetl/file/format/xml.py b/onetl/file/format/xml.py index cc7cd4777..885e337a2 100644 --- a/onetl/file/format/xml.py +++ b/onetl/file/format/xml.py @@ -193,6 +193,9 @@ def get_packages( # noqa: WPS231 ) """ + spark_ver = Version(spark_version) + if spark_ver.major >= 4: + return [] if package_version: version = Version(package_version).min_digits(3) @@ -202,7 +205,6 @@ def get_packages( # noqa: WPS231 else: version = Version("0.18.0").min_digits(3) - spark_ver = Version(spark_version) scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver) # Ensure compatibility with Spark and Scala versions diff --git a/requirements/tests/spark-latest.txt b/requirements/tests/spark-latest.txt index ea93652c9..038d47de7 100644 --- a/requirements/tests/spark-latest.txt +++ b/requirements/tests/spark-latest.txt @@ -1,5 +1,5 @@ numpy>=1.16 pandas>=1.0 pyarrow>=1.0 -pyspark +pyspark==4.0.0.dev1 sqlalchemy diff --git a/tests/fixtures/spark.py b/tests/fixtures/spark.py index e7248e84f..e91eccf21 100644 --- a/tests/fixtures/spark.py +++ b/tests/fixtures/spark.py @@ -103,7 +103,7 @@ def maven_packages(request): # There is no MongoDB connector for Spark less than 3.2 packages.extend(MongoDB.get_packages(spark_version=str(pyspark_version))) - if "excel" in markers: + if "excel" in markers and pyspark_version < Version("4.0"): # There is no Excel files support for Spark less than 3.2 packages.extend(Excel.get_packages(spark_version=str(pyspark_version)))