diff --git a/.github/workflows/data/clickhouse/matrix.yml b/.github/workflows/data/clickhouse/matrix.yml
index d0db924d4..cf52893ab 100644
--- a/.github/workflows/data/clickhouse/matrix.yml
+++ b/.github/workflows/data/clickhouse/matrix.yml
@@ -5,7 +5,7 @@ min: &min
os: ubuntu-latest
max: &max
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.github/workflows/data/core/matrix.yml b/.github/workflows/data/core/matrix.yml
index 042010659..78cb7f316 100644
--- a/.github/workflows/data/core/matrix.yml
+++ b/.github/workflows/data/core/matrix.yml
@@ -5,7 +5,7 @@ min: &min
os: ubuntu-latest
max: &max
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.github/workflows/data/hdfs/matrix.yml b/.github/workflows/data/hdfs/matrix.yml
index 48d72137c..eba19818e 100644
--- a/.github/workflows/data/hdfs/matrix.yml
+++ b/.github/workflows/data/hdfs/matrix.yml
@@ -7,7 +7,7 @@ min: &min
max: &max
hadoop-version: hadoop3-hdfs
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.github/workflows/data/hive/matrix.yml b/.github/workflows/data/hive/matrix.yml
index 66c39e9d4..17c1c3a6c 100644
--- a/.github/workflows/data/hive/matrix.yml
+++ b/.github/workflows/data/hive/matrix.yml
@@ -5,7 +5,7 @@ min: &min
os: ubuntu-latest
max: &max
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.github/workflows/data/kafka/matrix.yml b/.github/workflows/data/kafka/matrix.yml
index 5d614cd9b..7797d37d6 100644
--- a/.github/workflows/data/kafka/matrix.yml
+++ b/.github/workflows/data/kafka/matrix.yml
@@ -8,7 +8,7 @@ min: &min
max: &max
kafka-version: 3.5.1
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.github/workflows/data/local-fs/matrix.yml b/.github/workflows/data/local-fs/matrix.yml
index e956169ba..4329d6582 100644
--- a/.github/workflows/data/local-fs/matrix.yml
+++ b/.github/workflows/data/local-fs/matrix.yml
@@ -16,12 +16,18 @@ min_excel: &min_excel
java-version: 8
os: ubuntu-latest
-max: &max
+max_excel: &max_excel
spark-version: 3.4.1
python-version: '3.11'
java-version: 20
os: ubuntu-latest
+max: &max
+ spark-version: 3.5.0
+ python-version: '3.11'
+ java-version: 20
+ os: ubuntu-latest
+
latest: &latest
spark-version: latest
python-version: '3.11'
@@ -30,13 +36,13 @@ latest: &latest
matrix:
small:
+ - <<: *max_excel
- <<: *max
- - <<: *min_avro
- - <<: *min_excel
full:
- <<: *min
- <<: *min_avro
- <<: *min_excel
+ - <<: *max_excel
- <<: *max
nightly:
- <<: *min
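
How these anchors expand: YAML merge keys (``<<: *anchor``) copy the anchored mapping into each list item, so ``small`` keeps Excel tests pinned to Spark 3.4.1 (via ``max_excel``) while everything else moves to 3.5.0. A minimal sketch of the merge-key semantics, assuming PyYAML (values abridged):

.. code:: python

    import yaml

    doc = """
    max_excel: &max_excel
      spark-version: 3.4.1
    max: &max
      spark-version: 3.5.0
    matrix:
      small:
        - <<: *max_excel
        - <<: *max
    """

    matrix = yaml.safe_load(doc)["matrix"]
    print(matrix["small"][0]["spark-version"])  # 3.4.1 - Excel still on Spark 3.4
    print(matrix["small"][1]["spark-version"])  # 3.5.0
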
diff --git a/.github/workflows/data/mssql/matrix.yml b/.github/workflows/data/mssql/matrix.yml
index a5576ab65..b9941b583 100644
--- a/.github/workflows/data/mssql/matrix.yml
+++ b/.github/workflows/data/mssql/matrix.yml
@@ -5,7 +5,7 @@ min: &min
os: ubuntu-latest
max: &max
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.github/workflows/data/mysql/matrix.yml b/.github/workflows/data/mysql/matrix.yml
index 9b34a980b..39ba7034f 100644
--- a/.github/workflows/data/mysql/matrix.yml
+++ b/.github/workflows/data/mysql/matrix.yml
@@ -5,7 +5,7 @@ min: &min
os: ubuntu-latest
max: &max
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.github/workflows/data/oracle/matrix.yml b/.github/workflows/data/oracle/matrix.yml
index b51bb54e9..20086bf04 100644
--- a/.github/workflows/data/oracle/matrix.yml
+++ b/.github/workflows/data/oracle/matrix.yml
@@ -5,7 +5,7 @@ min: &min
os: ubuntu-latest
max: &max
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.github/workflows/data/postgres/matrix.yml b/.github/workflows/data/postgres/matrix.yml
index 1d27793f9..c5233c5e8 100644
--- a/.github/workflows/data/postgres/matrix.yml
+++ b/.github/workflows/data/postgres/matrix.yml
@@ -5,7 +5,7 @@ min: &min
os: ubuntu-latest
max: &max
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.github/workflows/data/s3/matrix.yml b/.github/workflows/data/s3/matrix.yml
index 44779fe95..2b6fbdb32 100644
--- a/.github/workflows/data/s3/matrix.yml
+++ b/.github/workflows/data/s3/matrix.yml
@@ -9,7 +9,7 @@ min: &min
max: &max
minio-version: 2023.7.18
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.github/workflows/data/teradata/matrix.yml b/.github/workflows/data/teradata/matrix.yml
index d6f8b6d20..05da497c8 100644
--- a/.github/workflows/data/teradata/matrix.yml
+++ b/.github/workflows/data/teradata/matrix.yml
@@ -1,5 +1,5 @@
max: &max
- spark-version: 3.4.1
+ spark-version: 3.5.0
python-version: '3.11'
java-version: 20
os: ubuntu-latest
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 193ae3c3d..9f0addf08 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.4.0
+ rev: v4.5.0
hooks:
- id: check-ast
- id: check-case-conflict
@@ -28,7 +28,7 @@ repos:
- id: remove-tabs
exclude: ^docs/(make.bat|Makefile)
- repo: https://github.com/codespell-project/codespell
- rev: v2.2.5
+ rev: v2.2.6
hooks:
- id: codespell
args: [-w]
@@ -59,7 +59,7 @@ repos:
- id: rst-inline-touching-normal
- id: text-unicode-replacement-char
- repo: https://github.com/asottile/pyupgrade
- rev: v3.13.0
+ rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py37-plus, --keep-runtime-typing]
diff --git a/README.rst b/README.rst
index 4f8b0aca8..792eb1dc4 100644
--- a/README.rst
+++ b/README.rst
@@ -52,7 +52,7 @@ Non-goals
Requirements
------------
* **Python 3.7 - 3.11**
-* PySpark 2.3.x - 3.4.x (depends on used connector)
+* PySpark 2.3.x - 3.5.x (depends on used connector)
* Java 8+ (required by Spark, see below)
* Kerberos libs & GCC (required by ``Hive``, ``HDFS`` and ``SparkHDFS`` connectors)
@@ -96,7 +96,7 @@ Supported storages
+ +--------------+----------------------------------------------------------------------------------------------------------------------+
| | Samba | `pysmb library `_ |
+--------------------+--------------+----------------------------------------------------------------------------------------------------------------------+
-| Files as DataFrame | SparkLocalFS | Apache Spark `File Data Source `_ |
+| Files as DataFrame | SparkLocalFS | Apache Spark `File Data Source `_ |
| +--------------+ +
| | SparkHDFS | |
| +--------------+----------------------------------------------------------------------------------------------------------------------+
@@ -179,6 +179,8 @@ Compatibility matrix
+--------------------------------------------------------------+-------------+-------------+-------+
| `3.4.x `_ | 3.7 - 3.11 | 8u362 - 20 | 2.12 |
+--------------------------------------------------------------+-------------+-------------+-------+
+| `3.5.x `_ | 3.8 - 3.11 | 8u371 - 20 | 2.12 |
++--------------------------------------------------------------+-------------+-------------+-------+
.. _pyspark-install:
@@ -192,7 +194,7 @@ or install PySpark explicitly:
.. code:: bash
- pip install onetl pyspark==3.4.1 # install a specific PySpark version
+ pip install onetl pyspark==3.5.0 # install a specific PySpark version
or inject PySpark to ``sys.path`` in some other way BEFORE creating a class instance.
**Otherwise connection object cannot be created.**
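
For example, a minimal sketch using the third-party ``findspark`` package, one common way to inject PySpark into ``sys.path`` (not the only option):

.. code:: python

    # findspark locates SPARK_HOME and appends pyspark to sys.path;
    # run it BEFORE importing or creating any onETL connection class.
    import findspark

    findspark.init()

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("onetl-demo").getOrCreate()
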
@@ -530,7 +532,7 @@ Read files directly from S3 path, convert them to dataframe, transform it and th
setup_logging()
# Initialize new SparkSession with Hadoop AWS libraries and Postgres driver loaded
- maven_packages = SparkS3.get_packages(spark_version="3.4.1") + Postgres.get_packages()
+ maven_packages = SparkS3.get_packages(spark_version="3.5.0") + Postgres.get_packages()
spark = (
SparkSession.builder.appName("spark_app_onetl_demo")
.config("spark.jars.packages", ",".join(maven_packages))
diff --git a/docs/changelog/0.9.4.rst b/docs/changelog/0.9.4.rst
index 4eb406ae0..886294486 100644
--- a/docs/changelog/0.9.4.rst
+++ b/docs/changelog/0.9.4.rst
@@ -4,12 +4,12 @@
Features
--------
-- Add ``if_exists="ignore"`` and ``error`` to ``Hive.WriteOptions`` (:github:pull:`143`)
-- Add ``if_exists="ignore"`` and ``error`` to ``JDBC.WriteOptions`` (:github:pull:`144`)
-- Add ``if_exists="ignore"`` and ``error`` to ``MongoDB.WriteOptions`` (:github:pull:`145`)
- Add ``Excel`` file format support. (:github:pull:`148`)
- Add ``Samba`` file connection.
It is now possible to download and upload files to Samba shared folders using ``FileDownloader``/``FileUploader``. (:github:pull:`150`)
+- Add ``if_exists="ignore"`` and ``error`` to ``Hive.WriteOptions`` (:github:pull:`143`)
+- Add ``if_exists="ignore"`` and ``error`` to ``JDBC.WriteOptions`` (:github:pull:`144`)
+- Add ``if_exists="ignore"`` and ``error`` to ``MongoDB.WriteOptions`` (:github:pull:`145`)
Improvements
@@ -21,10 +21,10 @@ Improvements
* Added interaction schemas for reading, writing and executing statements in Greenplum.
* Added recommendations about reading data from views and ``JOIN`` results from Greenplum. (:github:pull:`154`)
- Make ``.fetch`` and ``.execute`` methods of DB connections thread-safe. Each thread works with its own connection. (:github:pull:`156`)
-- Call ``.close()`` on FileConnection then it is removed by garbage collector. (:github:pull:`156`)
+- Call ``.close()`` on ``FileConnection`` when it is removed by garbage collector. (:github:pull:`156`)
Bug Fixes
---------
-- Fix issue while stopping Python interpreter calls ``JDBCMixin.close()`` and prints exceptions to log. (:github:pull:`156`)
+- Fix issue when stopping Python interpreter calls ``JDBCMixin.close()``, which then failed with exceptions. (:github:pull:`156`)
diff --git a/docs/changelog/0.9.5.rst b/docs/changelog/0.9.5.rst
new file mode 100644
index 000000000..c374f601e
--- /dev/null
+++ b/docs/changelog/0.9.5.rst
@@ -0,0 +1,20 @@
+0.9.5 (2023-10-10)
+==================
+
+Features
+--------
+
+- Add ``XML`` file format support. (:github:pull:`163`)
+- Tested compatibility with Spark 3.5.0. ``MongoDB`` and ``Excel`` are not supported yet, but other packages are. (:github:pull:`159`)
+
+
+Improvements
+------------
+
+- Add check to all DB and FileDF connections that Spark session is alive. (:github:pull:`164`)
+
+
+Bug Fixes
+---------
+
+- Fix ``Hive.check()`` behavior when Hive Metastore is not available. (:github:pull:`164`)
diff --git a/docs/changelog/NEXT_RELEASE.rst b/docs/changelog/NEXT_RELEASE.rst
index ee4196843..a598be017 100644
--- a/docs/changelog/NEXT_RELEASE.rst
+++ b/docs/changelog/NEXT_RELEASE.rst
@@ -1,36 +1,6 @@
-.. copy this file with new release name
-.. then fill it up using towncrier build
-.. and add it to index.rst
+.. fill up this file using ``towncrier build``
+.. then delete everything up to the header with version number
+.. then rename file to ``{VERSION}.rst`` and add it to index.rst
+.. and restore ``NEXT_RELEASE.rst`` content as it was before running the command above
.. towncrier release notes start
-
-0.9.4 (2023-09-26)
-==================
-
-Features
---------
-
-- Add ``if_exists="ignore"`` and ``error`` to ``Hive.WriteOptions`` (:github:pull:`143`)
-- Add ``if_exists="ignore"`` and ``error`` to ``JDBC.WriteOptions`` (:github:pull:`144`)
-- Add ``if_exists="ignore"`` and ``error`` to ``MongoDB.WriteOptions`` (:github:pull:`145`)
-- Add ``Excel`` file format support. (:github:pull:`148`)
-- Add ``Samba`` file connection.
- It is now possible to download and upload files to Samba shared folders using ``FileDownloader``/``FileUploader``. (:github:pull:`150`)
-
-
-Improvements
-------------
-
-- Add documentation about different ways of passing packages to Spark session. (:github:pull:`151`)
-- Drastically improve ``Greenplum`` documentation:
- * Added information about network ports, grants, ``pg_hba.conf`` and so on.
- * Added interaction schemas for reading, writing and executing statements in Greenplum.
- * Added recommendations about reading data from views and ``JOIN`` results from Greenplum. (:github:pull:`154`)
-- Make ``.fetch`` and ``.execute`` methods of DB connections thread-safe. Each thread works with its own connection. (:github:pull:`156`)
-- Call ``.close()`` on FileConnection then it is removed by garbage collector. (:github:pull:`156`)
-
-
-Bug Fixes
----------
-
-- Fix issue while stopping Python interpreter calls ``JDBCMixin.close()`` and prints exceptions to log. (:github:pull:`156`)
diff --git a/docs/changelog/index.rst b/docs/changelog/index.rst
index 6130bfdc8..276dd3cf6 100644
--- a/docs/changelog/index.rst
+++ b/docs/changelog/index.rst
@@ -4,6 +4,7 @@
DRAFT
NEXT_RELEASE
+ 0.9.5
0.9.4
0.9.3
0.9.2
diff --git a/docs/concepts.rst b/docs/concepts.rst
index f2f42a539..e4dc8facb 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -31,7 +31,120 @@ All connection types are inherited from the parent class ``BaseConnection``.
Class diagram
-------------
-.. image:: static/connections.svg
+.. plantuml::
+
+ @startuml
+ left to right direction
+ skinparam classFontSize 20
+ skinparam class {
+ BackgroundColor<<DB>> LightGreen
+ BackgroundColor<<File>> Khaki
+ BackgroundColor<<FileDF>> LightBlue
+ StereotypeFontColor<<DB>> Transparent
+ StereotypeFontColor<<File>> Transparent
+ StereotypeFontColor<<FileDF>> Transparent
+ }
+
+ class BaseConnection {
+ }
+
+ class DBConnection <<DB>> {
+ }
+ DBConnection --|> BaseConnection
+
+ class Hive <<DB>> {
+ }
+ Hive --|> DBConnection
+
+ class Greenplum <<DB>> {
+ }
+ Greenplum --|> DBConnection
+
+ class MongoDB <<DB>> {
+ }
+ MongoDB --|> DBConnection
+
+ class Kafka <<DB>> {
+ }
+ Kafka --|> DBConnection
+
+ class JDBCConnection <<DB>> {
+ }
+ JDBCConnection --|> DBConnection
+
+ class Clickhouse <<DB>> {
+ }
+ Clickhouse --|> JDBCConnection
+
+ class MSSQL <<DB>> {
+ }
+ MSSQL --|> JDBCConnection
+
+ class MySQL <<DB>> {
+ }
+ MySQL --|> JDBCConnection
+
+ class Postgres <<DB>> {
+ }
+ Postgres --|> JDBCConnection
+
+ class Oracle <<DB>> {
+ }
+ Oracle --|> JDBCConnection
+
+ class Teradata <<DB>> {
+ }
+ Teradata --|> JDBCConnection
+
+ class FileConnection <<File>> {
+ }
+ FileConnection --|> BaseConnection
+
+ class FTP <<File>> {
+ }
+ FTP --|> FileConnection
+
+ class FTPS <<File>> {
+ }
+ FTPS --|> FileConnection
+
+ class HDFS <<File>> {
+ }
+ HDFS --|> FileConnection
+
+ class WebDAV <<File>> {
+ }
+ WebDAV --|> FileConnection
+
+ class Samba <<File>> {
+ }
+ Samba --|> FileConnection
+
+ class SFTP <<File>> {
+ }
+ SFTP --|> FileConnection
+
+ class S3 <<File>> {
+ }
+ S3 --|> FileConnection
+
+ class FileDFConnection <<FileDF>> {
+ }
+ FileDFConnection --|> BaseConnection
+
+ class SparkHDFS <<FileDF>> {
+ }
+ SparkHDFS --|> FileDFConnection
+
+ class SparkLocalFS <<FileDF>> {
+ }
+ SparkLocalFS --|> FileDFConnection
+
+ class SparkS3 <<FileDF>> {
+ }
+ SparkS3 --|> FileDFConnection
+
+ @enduml
DBConnection
------------
diff --git a/docs/file_df/file_formats/index.rst b/docs/file_df/file_formats/index.rst
index 3a39bc061..abaeee2d2 100644
--- a/docs/file_df/file_formats/index.rst
+++ b/docs/file_df/file_formats/index.rst
@@ -14,6 +14,7 @@ File Formats
jsonline
orc
parquet
+ xml
.. toctree::
:maxdepth: 1
diff --git a/docs/file_df/file_formats/xml.rst b/docs/file_df/file_formats/xml.rst
new file mode 100644
index 000000000..187aa89a4
--- /dev/null
+++ b/docs/file_df/file_formats/xml.rst
@@ -0,0 +1,9 @@
+.. _xml-file-format:
+
+XML
+=====
+
+.. currentmodule:: onetl.file.format.xml
+
+.. autoclass:: XML
+ :members: get_packages
diff --git a/docs/static/connections.svg b/docs/static/connections.svg
deleted file mode 100644
index 7c7d11c9e..000000000
--- a/docs/static/connections.svg
+++ /dev/null
@@ -1,975 +0,0 @@
-
-
-
-
diff --git a/onetl/VERSION b/onetl/VERSION
index a602fc9e2..b0bb87854 100644
--- a/onetl/VERSION
+++ b/onetl/VERSION
@@ -1 +1 @@
-0.9.4
+0.9.5
diff --git a/onetl/connection/db_connection/clickhouse/connection.py b/onetl/connection/db_connection/clickhouse/connection.py
index f95884f7d..2ef521072 100644
--- a/onetl/connection/db_connection/clickhouse/connection.py
+++ b/onetl/connection/db_connection/clickhouse/connection.py
@@ -46,7 +46,7 @@ class Clickhouse(JDBCConnection):
.. dropdown:: Version compatibility
* Clickhouse server versions: 20.7 or higher
- * Spark versions: 2.3.x - 3.4.x
+ * Spark versions: 2.3.x - 3.5.x
* Java versions: 8 - 20
See `official documentation `_.
@@ -63,7 +63,7 @@ class Clickhouse(JDBCConnection):
pip install onetl[spark] # latest PySpark version
# or
- pip install onetl pyspark=3.4.1 # pass specific PySpark version
+ pip install onetl pyspark==3.5.0 # pass specific PySpark version
See :ref:`install-spark` installation instruction for more details.
diff --git a/onetl/connection/db_connection/db_connection/connection.py b/onetl/connection/db_connection/db_connection/connection.py
index 315f5b17c..731ef872d 100644
--- a/onetl/connection/db_connection/db_connection/connection.py
+++ b/onetl/connection/db_connection/db_connection/connection.py
@@ -17,7 +17,7 @@
from logging import getLogger
from typing import TYPE_CHECKING
-from pydantic import Field
+from pydantic import Field, validator
from onetl._util.spark import try_import_pyspark
from onetl.base import BaseDBConnection
@@ -48,6 +48,19 @@ def _forward_refs(cls) -> dict[str, type]:
refs["SparkSession"] = SparkSession
return refs
+ @validator("spark")
+ def _check_spark_session_alive(cls, spark):
+ # https://stackoverflow.com/a/36044685
+ msg = "Spark session is stopped. Please recreate Spark session."
+ try:
+ if not spark._jsc.sc().isStopped():
+ return spark
+ except Exception as e:
+ # None has no attribute "something"
+ raise ValueError(msg) from e
+
+ raise ValueError(msg)
+
def _log_parameters(self):
log.info("|%s| Using connection parameters:", self.__class__.__name__)
parameters = self.dict(exclude_none=True, exclude={"spark"})
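
What this validator changes for users, as a hedged sketch (connection class and credentials below are illustrative): passing a stopped Spark session to any DB connection now fails fast at construction time instead of surfacing later as an obscure Py4J error.

.. code:: python

    from pyspark.sql import SparkSession

    from onetl.connection import Postgres

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    spark.stop()

    # the validator above runs inside __init__ (pydantic) and raises immediately
    try:
        Postgres(host="example.org", user="me", password="***", database="db", spark=spark)
    except ValueError as e:
        print(e)  # includes "Spark session is stopped. Please recreate Spark session."
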
diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py
index 6d768ea2e..97cc034f5 100644
--- a/onetl/connection/db_connection/hive/connection.py
+++ b/onetl/connection/db_connection/hive/connection.py
@@ -52,7 +52,7 @@ class Hive(DBConnection):
.. dropdown:: Version compatibility
* Hive metastore version: 0.12 - 3.1.2 (may require to add proper .jar file explicitly)
- * Spark versions: 2.3.x - 3.4.x
+ * Spark versions: 2.3.x - 3.5.x
* Java versions: 8 - 20
.. warning::
@@ -67,7 +67,7 @@ class Hive(DBConnection):
pip install onetl[spark] # latest PySpark version
# or
- pip install onetl pyspark=3.4.1 # pass specific PySpark version
+ pip install onetl pyspark==3.5.0 # pass specific PySpark version
See :ref:`install-spark` installation instruction for more details.
@@ -146,7 +146,7 @@ class Hive(DBConnection):
# TODO: remove in v1.0.0
slots = HiveSlots
- _CHECK_QUERY: ClassVar[str] = "SELECT 1"
+ _CHECK_QUERY: ClassVar[str] = "SHOW DATABASES"
@slot
@classmethod
@@ -207,7 +207,7 @@ def check(self):
log_lines(log, self._CHECK_QUERY, level=logging.DEBUG)
try:
- self._execute_sql(self._CHECK_QUERY)
+ self._execute_sql(self._CHECK_QUERY).limit(1).collect()
log.info("|%s| Connection is available.", self.__class__.__name__)
except Exception as e:
log.exception("|%s| Connection is unavailable", self.__class__.__name__)
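
The reason for both changes, sketched below: ``SELECT 1`` could succeed without ever contacting the Hive Metastore, so ``check()`` reported success even when the Metastore was down. Collecting one row of ``SHOW DATABASES`` forces a real round-trip (the cluster name here is a made-up example):

.. code:: python

    from pyspark.sql import SparkSession

    from onetl.connection import Hive

    spark = SparkSession.builder.enableHiveSupport().getOrCreate()
    hive = Hive(cluster="rnd-dwh", spark=spark)  # hypothetical cluster name

    # now raises if the Hive Metastore is unavailable,
    # because the query result is actually collected
    hive.check()
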
diff --git a/onetl/connection/db_connection/kafka/connection.py b/onetl/connection/db_connection/kafka/connection.py
index 51053df0c..111b61b08 100644
--- a/onetl/connection/db_connection/kafka/connection.py
+++ b/onetl/connection/db_connection/kafka/connection.py
@@ -71,7 +71,7 @@ class Kafka(DBConnection):
.. dropdown:: Version compatibility
* Apache Kafka versions: 0.10 or higher
- * Spark versions: 2.4.x - 3.4.x
+ * Spark versions: 2.4.x - 3.5.x
* Scala versions: 2.11 - 2.13
Parameters
@@ -317,7 +317,7 @@ def write_df_to_target(
write_options.update(options.dict(by_alias=True, exclude_none=True, exclude={"if_exists"}))
write_options["topic"] = target
- # As of Apache Spark version 3.4.1, the mode 'error' is not functioning as expected.
+ # As of Apache Spark version 3.5.0, the mode 'error' is not functioning as expected.
# This issue has been reported and can be tracked at:
# https://issues.apache.org/jira/browse/SPARK-44774
mode = options.if_exists
diff --git a/onetl/connection/db_connection/mongodb/connection.py b/onetl/connection/db_connection/mongodb/connection.py
index 860f7b215..280596d5d 100644
--- a/onetl/connection/db_connection/mongodb/connection.py
+++ b/onetl/connection/db_connection/mongodb/connection.py
@@ -507,6 +507,7 @@ def write_df_to_target(
)
if self._collection_exists(target):
+ # MongoDB connector does not support mode=ignore and mode=error
if write_options.if_exists == MongoDBCollectionExistBehavior.ERROR:
raise ValueError("Operation stopped due to MongoDB.WriteOptions(if_exists='error')")
elif write_options.if_exists == MongoDBCollectionExistBehavior.IGNORE:
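
Since the MongoDB Spark connector lacks native ``error``/``ignore`` write modes, onETL emulates them before handing the DataFrame over; a short sketch of the user-facing options:

.. code:: python

    from onetl.connection import MongoDB

    # raise if the target collection already exists (emulated by onETL)
    MongoDB.WriteOptions(if_exists="error")

    # silently skip the write instead (also emulated by onETL)
    MongoDB.WriteOptions(if_exists="ignore")
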
diff --git a/onetl/connection/db_connection/mssql/connection.py b/onetl/connection/db_connection/mssql/connection.py
index 6738c2541..b4ab8427b 100644
--- a/onetl/connection/db_connection/mssql/connection.py
+++ b/onetl/connection/db_connection/mssql/connection.py
@@ -44,7 +44,7 @@ class MSSQL(JDBCConnection):
.. dropdown:: Version compatibility
* SQL Server versions: 2014 - 2022
- * Spark versions: 2.3.x - 3.4.x
+ * Spark versions: 2.3.x - 3.5.x
* Java versions: 8 - 20
See `official documentation `_
@@ -62,7 +62,7 @@ class MSSQL(JDBCConnection):
pip install onetl[spark] # latest PySpark version
# or
- pip install onetl pyspark=3.4.1 # pass specific PySpark version
+ pip install onetl pyspark==3.5.0 # pass specific PySpark version
See :ref:`install-spark` installation instruction for more details.
diff --git a/onetl/connection/db_connection/mysql/connection.py b/onetl/connection/db_connection/mysql/connection.py
index abd17df33..326ae7a81 100644
--- a/onetl/connection/db_connection/mysql/connection.py
+++ b/onetl/connection/db_connection/mysql/connection.py
@@ -44,7 +44,7 @@ class MySQL(JDBCConnection):
.. dropdown:: Version compatibility
* MySQL server versions: 5.7, 8.0
- * Spark versions: 2.3.x - 3.4.x
+ * Spark versions: 2.3.x - 3.5.x
* Java versions: 8 - 20
See `official documentation `_.
@@ -61,7 +61,7 @@ class MySQL(JDBCConnection):
pip install onetl[spark] # latest PySpark version
# or
- pip install onetl pyspark=3.4.1 # pass specific PySpark version
+ pip install onetl pyspark==3.5.0 # pass specific PySpark version
See :ref:`install-spark` installation instruction for more details.
diff --git a/onetl/connection/db_connection/oracle/connection.py b/onetl/connection/db_connection/oracle/connection.py
index 2e1f3e916..5166ee71e 100644
--- a/onetl/connection/db_connection/oracle/connection.py
+++ b/onetl/connection/db_connection/oracle/connection.py
@@ -84,7 +84,7 @@ class Oracle(JDBCConnection):
.. dropdown:: Version compatibility
* Oracle Server versions: 23c, 21c, 19c, 18c, 12.2 and probably 11.2 (tested, but that's not official).
- * Spark versions: 2.3.x - 3.4.x
+ * Spark versions: 2.3.x - 3.5.x
* Java versions: 8 - 20
See `official documentation `_.
@@ -101,7 +101,7 @@ class Oracle(JDBCConnection):
pip install onetl[spark] # latest PySpark version
# or
- pip install onetl pyspark=3.4.1 # pass specific PySpark version
+ pip install onetl pyspark==3.5.0 # pass specific PySpark version
See :ref:`install-spark` installation instruction for more details.
diff --git a/onetl/connection/db_connection/postgres/connection.py b/onetl/connection/db_connection/postgres/connection.py
index 22b42c296..a0bd508b0 100644
--- a/onetl/connection/db_connection/postgres/connection.py
+++ b/onetl/connection/db_connection/postgres/connection.py
@@ -42,7 +42,7 @@ class Postgres(JDBCConnection):
.. dropdown:: Version compatibility
* PostgreSQL server versions: 8.2 or higher
- * Spark versions: 2.3.x - 3.4.x
+ * Spark versions: 2.3.x - 3.5.x
* Java versions: 8 - 20
See `official documentation `_.
@@ -59,7 +59,7 @@ class Postgres(JDBCConnection):
pip install onetl[spark] # latest PySpark version
# or
- pip install onetl pyspark=3.4.1 # pass specific PySpark version
+ pip install onetl pyspark==3.5.0 # pass specific PySpark version
See :ref:`install-spark` installation instruction for more details.
diff --git a/onetl/connection/db_connection/teradata/connection.py b/onetl/connection/db_connection/teradata/connection.py
index 2c797b3d8..ed46d24d1 100644
--- a/onetl/connection/db_connection/teradata/connection.py
+++ b/onetl/connection/db_connection/teradata/connection.py
@@ -47,7 +47,7 @@ class Teradata(JDBCConnection):
.. dropdown:: Version compatibility
* Teradata server versions: 16.10 - 20.0
- * Spark versions: 2.3.x - 3.4.x
+ * Spark versions: 2.3.x - 3.5.x
* Java versions: 8 - 20
See `official documentation `_.
@@ -64,7 +64,7 @@ class Teradata(JDBCConnection):
pip install onetl[spark] # latest PySpark version
# or
- pip install onetl pyspark=3.4.1 # pass specific PySpark version
+ pip install onetl pyspark==3.5.0 # pass specific PySpark version
See :ref:`install-spark` installation instruction for more details.
diff --git a/onetl/connection/file_df_connection/spark_file_df_connection.py b/onetl/connection/file_df_connection/spark_file_df_connection.py
index 7c1994182..10853078b 100644
--- a/onetl/connection/file_df_connection/spark_file_df_connection.py
+++ b/onetl/connection/file_df_connection/spark_file_df_connection.py
@@ -19,7 +19,7 @@
from logging import getLogger
from typing import TYPE_CHECKING
-from pydantic import Field
+from pydantic import Field, validator
from onetl._util.hadoop import get_hadoop_config
from onetl._util.spark import try_import_pyspark
@@ -182,6 +182,19 @@ def _forward_refs(cls) -> dict[str, type]:
refs["SparkSession"] = SparkSession
return refs
+ @validator("spark")
+ def _check_spark_session_alive(cls, spark):
+ # https://stackoverflow.com/a/36044685
+ msg = "Spark session is stopped. Please recreate Spark session."
+ try:
+ if not spark._jsc.sc().isStopped():
+ return spark
+ except Exception as e:
+ # None has no attribute "something"
+ raise ValueError(msg) from e
+
+ raise ValueError(msg)
+
def _log_parameters(self):
log.info("|%s| Using connection parameters:", self.__class__.__name__)
parameters = self.dict(exclude_none=True, exclude={"spark"})
diff --git a/onetl/connection/file_df_connection/spark_hdfs/connection.py b/onetl/connection/file_df_connection/spark_hdfs/connection.py
index 6855fe595..76d1520cc 100644
--- a/onetl/connection/file_df_connection/spark_hdfs/connection.py
+++ b/onetl/connection/file_df_connection/spark_hdfs/connection.py
@@ -57,7 +57,7 @@ class SparkHDFS(SparkFileDFConnection):
pip install onetl[spark] # latest PySpark version
# or
- pip install onetl pyspark=3.4.1 # pass specific PySpark version
+ pip install onetl pyspark==3.5.0 # pass specific PySpark version
See :ref:`install-spark` installation instruction for more details.
diff --git a/onetl/connection/file_df_connection/spark_local_fs.py b/onetl/connection/file_df_connection/spark_local_fs.py
index 264fac3a2..b2fa4625e 100644
--- a/onetl/connection/file_df_connection/spark_local_fs.py
+++ b/onetl/connection/file_df_connection/spark_local_fs.py
@@ -47,7 +47,7 @@ class SparkLocalFS(SparkFileDFConnection):
pip install onetl[spark] # latest PySpark version
# or
- pip install onetl pyspark=3.4.1 # pass specific PySpark version
+ pip install onetl pyspark==3.5.0 # pass specific PySpark version
See :ref:`install-spark` installation instruction for more details.
diff --git a/onetl/connection/file_df_connection/spark_s3/connection.py b/onetl/connection/file_df_connection/spark_s3/connection.py
index 992e11627..dd881988c 100644
--- a/onetl/connection/file_df_connection/spark_s3/connection.py
+++ b/onetl/connection/file_df_connection/spark_s3/connection.py
@@ -62,7 +62,7 @@ class SparkS3(SparkFileDFConnection):
.. dropdown:: Version compatibility
- * Spark versions: 3.2.x - 3.4.x (only with Hadoop 3.x libraries)
+ * Spark versions: 3.2.x - 3.5.x (only with Hadoop 3.x libraries)
* Java versions: 8 - 20
* Scala versions: 2.11 - 2.13
@@ -82,7 +82,7 @@ class SparkS3(SparkFileDFConnection):
pip install onetl[spark] # latest PySpark version
# or
- pip install onetl pyspark=3.4.1 # pass specific PySpark version
+ pip install onetl pyspark==3.5.0 # pass specific PySpark version
See :ref:`install-spark` installation instruction for more details.
@@ -161,7 +161,7 @@ class SparkS3(SparkFileDFConnection):
from pyspark.sql import SparkSession
# Create Spark session with Hadoop AWS libraries loaded
- maven_packages = SparkS3.get_packages(spark_version="3.4.1")
+ maven_packages = SparkS3.get_packages(spark_version="3.5.0")
# Some dependencies are not used, but downloading takes a lot of time. Skipping them.
excluded_packages = [
"com.google.cloud.bigdataoss:gcs-connector",
@@ -262,8 +262,8 @@ def get_packages(
from onetl.connection import SparkS3
- SparkS3.get_packages(spark_version="3.4.1")
- SparkS3.get_packages(spark_version="3.4.1", scala_version="2.12")
+ SparkS3.get_packages(spark_version="3.5.0")
+ SparkS3.get_packages(spark_version="3.5.0", scala_version="2.12")
"""
diff --git a/onetl/file/file_downloader/options.py b/onetl/file/file_downloader/options.py
index 9ec44ce52..a1b635f81 100644
--- a/onetl/file/file_downloader/options.py
+++ b/onetl/file/file_downloader/options.py
@@ -31,8 +31,8 @@ class FileDownloaderOptions(GenericOptions):
Possible values:
* ``error`` (default) - do nothing, mark file as failed
* ``ignore`` - do nothing, mark file as ignored
- * ``overwrite`` - replace existing file with a new one
- * ``delete_all`` - delete local directory content before downloading files
+ * ``replace_file`` - replace existing file with a new one
+ * ``replace_entire_directory`` - delete local directory content before downloading files
"""
delete_source: bool = False
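
A short sketch with the renamed values (the ``sftp`` connection object is assumed to be created elsewhere):

.. code:: python

    from onetl.file import FileDownloader

    downloader = FileDownloader(
        connection=sftp,  # any FileConnection
        source_path="/remote/data",
        local_path="/tmp/data",
        # was if_exists="overwrite" before the rename
        options=FileDownloader.Options(if_exists="replace_file"),
    )
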
diff --git a/onetl/file/file_mover/options.py b/onetl/file/file_mover/options.py
index 912c0ae1b..a3a89727c 100644
--- a/onetl/file/file_mover/options.py
+++ b/onetl/file/file_mover/options.py
@@ -31,8 +31,8 @@ class FileMoverOptions(GenericOptions):
Possible values:
* ``error`` (default) - do nothing, mark file as failed
* ``ignore`` - do nothing, mark file as ignored
- * ``overwrite`` - replace existing file with a new one
- * ``delete_all`` - delete directory content before moving files
+ * ``replace_file`` - replace existing file with a new one
+ * ``replace_entire_directory`` - delete directory content before moving files
"""
workers: int = Field(default=1, ge=1)
diff --git a/onetl/file/file_uploader/options.py b/onetl/file/file_uploader/options.py
index e3dd78bb3..5d1f1dbdd 100644
--- a/onetl/file/file_uploader/options.py
+++ b/onetl/file/file_uploader/options.py
@@ -31,8 +31,8 @@ class FileUploaderOptions(GenericOptions):
Possible values:
* ``error`` (default) - do nothing, mark file as failed
* ``ignore`` - do nothing, mark file as ignored
- * ``overwrite`` - replace existing file with a new one
- * ``delete_all`` - delete local directory content before downloading files
+ * ``replace_file`` - replace existing file with a new one
+ * ``replace_entire_directory`` - delete remote directory content before uploading files
"""
delete_local: bool = False
diff --git a/onetl/file/format/__init__.py b/onetl/file/format/__init__.py
index 0c9d6b742..74475c8a9 100644
--- a/onetl/file/format/__init__.py
+++ b/onetl/file/format/__init__.py
@@ -20,3 +20,4 @@
from onetl.file.format.jsonline import JSONLine
from onetl.file.format.orc import ORC
from onetl.file.format.parquet import Parquet
+from onetl.file.format.xml import XML
diff --git a/onetl/file/format/avro.py b/onetl/file/format/avro.py
index b0c58e18d..b07c34b01 100644
--- a/onetl/file/format/avro.py
+++ b/onetl/file/format/avro.py
@@ -71,7 +71,7 @@ class Avro(ReadWriteFileFormat):
.. dropdown:: Version compatibility
- * Spark versions: 2.4.x - 3.4.x
+ * Spark versions: 2.4.x - 3.5.x
* Java versions: 8 - 20
* Scala versions: 2.11 - 2.13
@@ -95,7 +95,7 @@ class Avro(ReadWriteFileFormat):
from pyspark.sql import SparkSession
# Create Spark session with Avro package loaded
- maven_packages = Avro.get_packages(spark_version="3.4.1")
+ maven_packages = Avro.get_packages(spark_version="3.5.0")
spark = (
SparkSession.builder.appName("spark-app-name")
.config("spark.jars.packages", ",".join(maven_packages))
diff --git a/onetl/file/format/xml.py b/onetl/file/format/xml.py
new file mode 100644
index 000000000..1a4c99803
--- /dev/null
+++ b/onetl/file/format/xml.py
@@ -0,0 +1,237 @@
+# Copyright 2023 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, ClassVar
+
+from pydantic import Field
+
+from onetl._util.java import try_import_java_class
+from onetl._util.scala import get_default_scala_version
+from onetl._util.spark import get_spark_version
+from onetl._util.version import Version
+from onetl.exception import MISSING_JVM_CLASS_MSG
+from onetl.file.format.file_format import ReadWriteFileFormat
+from onetl.hooks import slot, support_hooks
+
+if TYPE_CHECKING:
+ from pyspark.sql import SparkSession
+
+
+PROHIBITED_OPTIONS = frozenset(
+ (
+ # filled by onETL classes
+ "path",
+ ),
+)
+
+
+READ_OPTIONS = frozenset(
+ (
+ "rowTag",
+ "samplingRatio",
+ "excludeAttribute",
+ "treatEmptyValuesAsNulls",
+ "mode",
+ "inferSchema",
+ "columnNameOfCorruptRecord",
+ "attributePrefix",
+ "valueTag",
+ "charset",
+ "ignoreSurroundingSpaces",
+ "wildcardColName",
+ "rowValidationXSDPath",
+ "ignoreNamespace",
+ "timestampFormat",
+ "dateFormat",
+ ),
+)
+
+WRITE_OPTIONS = frozenset(
+ (
+ "rowTag",
+ "rootTag",
+ "declaration",
+ "arrayElementName",
+ "nullValue",
+ "attributePrefix",
+ "valueTag",
+ "compression",
+ "timestampFormat",
+ "dateFormat",
+ ),
+)
+
+
+log = logging.getLogger(__name__)
+
+
+@support_hooks
+class XML(ReadWriteFileFormat):
+ """
+ XML file format. |support_hooks|
+
+ Based on `Databricks Spark XML `_ file format.
+
+ Supports reading/writing files with ``.xml`` extension.
+
+ .. warning::
+
+ Due to a `bug `_, written files currently do not have the ``.xml`` extension.
+
+ .. versionadded:: 0.9.5
+
+ .. dropdown:: Version compatibility
+
+ * Spark versions: 3.2.x - 3.5.x
+ * Scala versions: 2.12 - 2.13
+ * Java versions: 8 - 20
+
+ See the documentation linked above.
+
+ .. note ::
+
+ You can pass any option to the constructor, even if it is not mentioned in this documentation.
+ **Option names should be in** ``camelCase``!
+
+ The set of supported options depends on Spark version. See link above.
+
+ .. warning::
+
+ By default, reading is done using ``mode=PERMISSIVE``, which replaces values of a wrong data type or format with ``null``.
+ Be careful while parsing values like timestamps; they should match the ``timestampFormat`` option.
+ Using ``mode=FAILFAST`` will throw an exception instead of producing ``null`` values.
+ `Follow `_ for more details.
+
+ Examples
+ --------
+ How to read from / write to an XML file with specific options:
+
+ .. code:: python
+
+ from onetl.file.format import XML
+ from pyspark.sql import SparkSession
+
+ # Create Spark session with XML package loaded
+ maven_packages = XML.get_packages(spark_version="3.5.0")
+ spark = (
+ SparkSession.builder.appName("spark-app-name")
+ .config("spark.jars.packages", ",".join(maven_packages))
+ .getOrCreate()
+ )
+
+ xml = XML(row_tag="item")
+
+ """
+
+ name: ClassVar[str] = "xml"
+
+ row_tag: str = Field(alias="rowTag")
+
+ class Config:
+ known_options = READ_OPTIONS | WRITE_OPTIONS
+ prohibited_options = PROHIBITED_OPTIONS
+ extra = "allow"
+
+ @slot
+ @classmethod
+ def get_packages( # noqa: WPS231
+ cls,
+ spark_version: str,
+ scala_version: str | None = None,
+ package_version: str | None = None,
+ ) -> list[str]:
+ """
+ Get package names to be downloaded by Spark. |support_hooks|
+
+ Parameters
+ ----------
+ spark_version : str
+ Spark version in format ``major.minor.patch``.
+
+ scala_version : str, optional
+ Scala version in format ``major.minor``.
+
+ If ``None``, ``spark_version`` is used to determine Scala version.
+
+ package_version : str, optional
+ Package version in format ``major.minor.patch``. Default is ``0.17.0``.
+
+ See `Maven index `_
+ for list of available versions.
+
+ .. warning::
+
+ Version ``0.13`` and below are not supported.
+
+ .. note::
+
+ It is not guaranteed that custom package versions are supported.
+ Tests are performed only for default version.
+
+ Examples
+ --------
+
+ .. code:: python
+
+ from onetl.file.format import XML
+
+ XML.get_packages(spark_version="3.5.0")
+ XML.get_packages(spark_version="3.5.0", scala_version="2.12")
+ XML.get_packages(
+ spark_version="3.5.0",
+ scala_version="2.12",
+ package_version="0.17.0",
+ )
+
+ """
+
+ if package_version:
+ version = Version.parse(package_version)
+ if version < (0, 14):
+ raise ValueError(f"Package version must be above 0.13, got {version}")
+ log.warning("Passed custom package version %r, it is not guaranteed to be supported", package_version)
+ else:
+ version = Version.parse("0.17.0")
+
+ spark_ver = Version.parse(spark_version)
+ scala_ver = Version.parse(scala_version) if scala_version else get_default_scala_version(spark_ver)
+
+ # Ensure compatibility with Spark and Scala versions
+ if spark_ver < (3, 0):
+ raise ValueError(f"Spark version must be 3.x, got {spark_ver}")
+
+ if scala_ver < (2, 12) or scala_ver > (2, 13):
+ raise ValueError(f"Scala version must be 2.12 or 2.13, got {scala_ver}")
+
+ return [f"com.databricks:spark-xml_{scala_ver.digits(2)}:{version.digits(3)}"]
+
+ @slot
+ def check_if_supported(self, spark: SparkSession) -> None:
+ java_class = "com.databricks.spark.xml.XmlReader"
+
+ try:
+ try_import_java_class(spark, java_class)
+ except Exception as e:
+ spark_version = get_spark_version(spark)
+ msg = MISSING_JVM_CLASS_MSG.format(
+ java_class=java_class,
+ package_source=self.__class__.__name__,
+ args=f"spark_version='{spark_version}'",
+ )
+ if log.isEnabledFor(logging.DEBUG):
+ log.debug("Missing Java class", exc_info=e, stack_info=True)
+ raise ValueError(msg) from e
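
A hedged end-to-end sketch of reading XML files with the new format class via ``FileDFReader`` (the connection object and path are assumptions; see the integration tests below for real usage):

.. code:: python

    from onetl.file import FileDFReader
    from onetl.file.format import XML

    reader = FileDFReader(
        connection=local_fs,  # a SparkLocalFS instance, created elsewhere
        format=XML(row_tag="item"),  # each <item>...</item> becomes one row
        source_path="/data/xml",
    )
    df = reader.run()
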
diff --git a/pyproject.toml b/pyproject.toml
index 3fa87a807..396b3094a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ multi_line_output = 3
[tool.black]
line-length = 120
-target-version = ['py37', 'py38', 'py39', 'py310', 'py311']
+target-version = ['py37', 'py38', 'py39', 'py310', 'py311', 'py312']
include = '\.pyi?$'
exclude = '''(\.eggs|\.git|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)'''
diff --git a/requirements/tests/spark-3.5.0.txt b/requirements/tests/spark-3.5.0.txt
new file mode 100644
index 000000000..2e49168a5
--- /dev/null
+++ b/requirements/tests/spark-3.5.0.txt
@@ -0,0 +1,5 @@
+numpy>=1.16
+pandas>=1.0
+pyarrow>=1.0
+pyspark==3.5.0
+sqlalchemy
diff --git a/tests/.coveragerc b/tests/.coveragerc
index 17cafe455..218452499 100644
--- a/tests/.coveragerc
+++ b/tests/.coveragerc
@@ -19,3 +19,4 @@ exclude_lines =
if pyspark_version
spark = SparkSession._instantiatedSession
if log.isEnabledFor(logging.DEBUG):
+ if sys.version_info
diff --git a/tests/fixtures/connections/file_df_connections.py b/tests/fixtures/connections/file_df_connections.py
index 3a9bea44a..faf3bb3b6 100644
--- a/tests/fixtures/connections/file_df_connections.py
+++ b/tests/fixtures/connections/file_df_connections.py
@@ -56,13 +56,62 @@ def file_df_schema_str_value_last():
@pytest.fixture()
def file_df_dataframe(spark, file_df_schema):
data = [
- [1, "val1", 123, datetime.date(2021, 1, 1), datetime.datetime(2021, 1, 1, 1, 1, 1), 1.23],
- [2, "val1", 234, datetime.date(2022, 2, 2), datetime.datetime(2022, 2, 2, 2, 2, 2), 2.34],
- [3, "val2", 345, datetime.date(2023, 3, 3), datetime.datetime(2023, 3, 3, 3, 3, 3), 3.45],
- [4, "val2", 456, datetime.date(2024, 4, 4), datetime.datetime(2024, 4, 4, 4, 4, 4), 4.56],
- [5, "val3", 567, datetime.date(2025, 5, 5), datetime.datetime(2025, 5, 5, 5, 5, 5), 5.67],
- [6, "val3", 678, datetime.date(2026, 6, 6), datetime.datetime(2026, 6, 6, 6, 6, 6), 6.78],
- [7, "val3", 789, datetime.date(2027, 7, 7), datetime.datetime(2027, 7, 7, 7, 7, 7), 7.89],
+ [
+ 1,
+ "val1",
+ 123,
+ datetime.date(2021, 1, 1),
+ datetime.datetime(2021, 1, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
+ 1.23,
+ ],
+ [
+ 2,
+ "val1",
+ 234,
+ datetime.date(2022, 2, 2),
+ datetime.datetime(2022, 2, 2, 2, 2, 2, tzinfo=datetime.timezone.utc),
+ 2.34,
+ ],
+ [
+ 3,
+ "val2",
+ 345,
+ datetime.date(2023, 3, 3),
+ datetime.datetime(2023, 3, 3, 3, 3, 3, tzinfo=datetime.timezone.utc),
+ 3.45,
+ ],
+ [
+ 4,
+ "val2",
+ 456,
+ datetime.date(2024, 4, 4),
+ datetime.datetime(2024, 4, 4, 4, 4, 4, tzinfo=datetime.timezone.utc),
+ 4.56,
+ ],
+ [
+ 5,
+ "val3",
+ 567,
+ datetime.date(2025, 5, 5),
+ datetime.datetime(2025, 5, 5, 5, 5, 5, tzinfo=datetime.timezone.utc),
+ 5.67,
+ ],
+ [
+ 6,
+ "val3",
+ 678,
+ datetime.date(2026, 6, 6),
+ datetime.datetime(2026, 6, 6, 6, 6, 6, tzinfo=datetime.timezone.utc),
+ 6.78,
+ ],
+ [
+ 7,
+ "val3",
+ 789,
+ datetime.date(2027, 7, 7),
+ datetime.datetime(2027, 7, 7, 7, 7, 7, tzinfo=datetime.timezone.utc),
+ 7.89,
+ ],
]
return spark.createDataFrame(data, schema=file_df_schema)
diff --git a/tests/fixtures/spark.py b/tests/fixtures/spark.py
index 05358b9c0..452c0f978 100644
--- a/tests/fixtures/spark.py
+++ b/tests/fixtures/spark.py
@@ -44,7 +44,7 @@ def maven_packages():
SparkS3,
Teradata,
)
- from onetl.file.format import Avro, Excel
+ from onetl.file.format import XML, Avro, Excel
pyspark_version = get_pyspark_version()
packages = (
@@ -71,11 +71,16 @@ def maven_packages():
# There is no SparkS3 connector for Spark less than 3
packages.extend(SparkS3.get_packages(spark_version=pyspark_version))
+ # There is no XML file support for Spark less than 3
+ packages.extend(XML.get_packages(spark_version=pyspark_version))
+
# There is no MongoDB connector for Spark less than 3.2
packages.extend(MongoDB.get_packages(spark_version=pyspark_version))
- # There is no Excel files support for Spark less than 3.2
- packages.extend(Excel.get_packages(spark_version=pyspark_version))
+ if pyspark_version < (3, 5):
+ # There is no Excel file support for Spark less than 3.2,
+ # and no package released for Spark 3.5.0 yet: https://github.com/crealytics/spark-excel/issues/787
+ packages.extend(Excel.get_packages(spark_version=pyspark_version))
return packages
diff --git a/tests/fixtures/spark_mock.py b/tests/fixtures/spark_mock.py
index b09e7764e..43085f50b 100644
--- a/tests/fixtures/spark_mock.py
+++ b/tests/fixtures/spark_mock.py
@@ -3,6 +3,23 @@
import pytest
+@pytest.fixture(
+ scope="function",
+ params=[pytest.param("mock-spark-stopped", marks=[pytest.mark.db_connection, pytest.mark.connection])],
+)
+def spark_stopped():
+ import pyspark
+ from pyspark.sql import SparkSession
+
+ spark = Mock(spec=SparkSession)
+ spark.sparkContext = Mock()
+ spark.sparkContext.appName = "abc"
+ spark.version = pyspark.__version__
+ spark._sc = Mock()
+ spark._sc._gateway = Mock()
+ spark._jsc = Mock()
+ spark._jsc.sc = Mock()
+ # mimic a SparkSession which was stopped by the user
+ spark._jsc.sc().isStopped = Mock(return_value=True)
+ return spark
+
+
@pytest.fixture(
scope="function",
params=[pytest.param("mock-spark-no-packages", marks=[pytest.mark.db_connection, pytest.mark.connection])],
@@ -15,6 +32,9 @@ def spark_no_packages():
spark.sparkContext = Mock()
spark.sparkContext.appName = "abc"
spark.version = pyspark.__version__
+ spark._jsc = Mock()
+ spark._jsc.sc = Mock()
+ spark._jsc.sc().isStopped = Mock(return_value=False)
return spark
@@ -29,7 +49,10 @@ def spark_mock():
spark = Mock(spec=SparkSession)
spark.sparkContext = Mock()
spark.sparkContext.appName = "abc"
+ spark.version = pyspark.__version__
spark._sc = Mock()
spark._sc._gateway = Mock()
- spark.version = pyspark.__version__
+ spark._jsc = Mock()
+ spark._jsc.sc = Mock()
+ spark._jsc.sc().isStopped = Mock(return_value=False)
return spark
diff --git a/tests/resources/file_df_connection/generate_files.py b/tests/resources/file_df_connection/generate_files.py
index 698c81ea7..2417cb15d 100755
--- a/tests/resources/file_df_connection/generate_files.py
+++ b/tests/resources/file_df_connection/generate_files.py
@@ -16,6 +16,7 @@
from pathlib import Path
from tempfile import gettempdir
from typing import TYPE_CHECKING, Any, Iterator, TextIO
+from xml.etree import ElementTree # noqa: S405
from zipfile import ZipFile
if TYPE_CHECKING:
@@ -472,6 +473,72 @@ def save_as_xls(data: list[dict], path: Path) -> None:
)
+def save_as_xml_plain(data: list[dict], path: Path) -> None:
+ path.mkdir(parents=True, exist_ok=True)
+ root = ElementTree.Element("root")
+
+ for record in data:
+ item = ElementTree.SubElement(root, "item")
+ for key, value in record.items():
+ child = ElementTree.SubElement(item, key)
+ if isinstance(value, datetime):
+ child.text = value.isoformat()
+ else:
+ child.text = str(value)
+
+ tree = ElementTree.ElementTree(root)
+ tree.write(path / "file.xml")
+
+
+def save_as_xml_with_attributes(data: list[dict], path: Path) -> None:
+ path.mkdir(parents=True, exist_ok=True)
+ root = ElementTree.Element("root")
+
+ for record in data:
+ str_attributes = {
+ key: value.isoformat() if isinstance(value, datetime) else str(value) for key, value in record.items()
+ }
+ item = ElementTree.SubElement(root, "item", attrib=str_attributes)
+ for key, value in record.items():
+ child = ElementTree.SubElement(item, key)
+ if isinstance(value, datetime):
+ child.text = value.isoformat()
+ else:
+ child.text = str(value)
+
+ tree = ElementTree.ElementTree(root)
+ tree.write(str(path / "file_with_attributes.xml"))
+
+
+def save_as_xml_gz(data: list[dict], path: Path) -> None:
+ path.mkdir(parents=True, exist_ok=True)
+ root = ElementTree.Element("root")
+
+ for record in data:
+ item = ElementTree.SubElement(root, "item")
+ for key, value in record.items():
+ child = ElementTree.SubElement(item, key)
+ if isinstance(value, datetime):
+ child.text = value.isoformat()
+ else:
+ child.text = str(value)
+
+ xml_string = ElementTree.tostring(root, encoding="utf-8")
+
+ with gzip.open(path / "file.xml.gz", "wb", compresslevel=9) as f:
+ f.write(xml_string)
+
+
+def save_as_xml(data: list[dict], path: Path) -> None:
+ root = path / "xml"
+ shutil.rmtree(root, ignore_errors=True)
+
+ save_as_xml_plain(data, root / "without_compression")
+ save_as_xml_with_attributes(data, root / "with_attributes")
+ save_as_xml_gz(data, root / "with_compression")
+
+
format_mapping = {
"csv": save_as_csv,
"json": save_as_json,
@@ -481,6 +548,7 @@ def save_as_xls(data: list[dict], path: Path) -> None:
"avro": save_as_avro,
"xlsx": save_as_xlsx,
"xls": save_as_xls,
+ "xml": save_as_xml,
}
diff --git a/tests/resources/file_df_connection/xml/with_attributes/file_with_attributes.xml b/tests/resources/file_df_connection/xml/with_attributes/file_with_attributes.xml
new file mode 100644
index 000000000..f6fcbc7df
--- /dev/null
+++ b/tests/resources/file_df_connection/xml/with_attributes/file_with_attributes.xml
@@ -0,0 +1 @@
+- 1val11232021-01-012021-01-01T01:01:01+00:001.23
- 2val12342022-02-022022-02-02T02:02:02+00:002.34
- 3val23452023-03-032023-03-03T03:03:03+00:003.45
- 4val24562024-04-042024-04-04T04:04:04+00:004.56
- 5val35672025-05-052025-05-05T05:05:05+00:005.67
- 6val36782026-06-062026-06-06T06:06:06+00:006.78
- 7val37892027-07-072027-07-07T07:07:07+00:007.89
\ No newline at end of file
diff --git a/tests/resources/file_df_connection/xml/with_compression/file.xml.gz b/tests/resources/file_df_connection/xml/with_compression/file.xml.gz
new file mode 100644
index 000000000..aefbf24af
Binary files /dev/null and b/tests/resources/file_df_connection/xml/with_compression/file.xml.gz differ
diff --git a/tests/resources/file_df_connection/xml/without_compression/file.xml b/tests/resources/file_df_connection/xml/without_compression/file.xml
new file mode 100644
index 000000000..79d7a0ddb
--- /dev/null
+++ b/tests/resources/file_df_connection/xml/without_compression/file.xml
@@ -0,0 +1 @@
+- 1val11232021-01-012021-01-01T01:01:01+00:001.23
- 2val12342022-02-022022-02-02T02:02:02+00:002.34
- 3val23452023-03-032023-03-03T03:03:03+00:003.45
- 4val24562024-04-042024-04-04T04:04:04+00:004.56
- 5val35672025-05-052025-05-05T05:05:05+00:005.67
- 6val36782026-06-062026-06-06T06:06:06+00:006.78
- 7val37892027-07-072027-07-07T07:07:07+00:007.89
\ No newline at end of file
diff --git a/tests/tests_integration/test_file_format_integration/test_excel_integration.py b/tests/tests_integration/test_file_format_integration/test_excel_integration.py
index de8cc9cf9..f9aaad38f 100644
--- a/tests/tests_integration/test_file_format_integration/test_excel_integration.py
+++ b/tests/tests_integration/test_file_format_integration/test_excel_integration.py
@@ -33,6 +33,8 @@ def test_excel_reader_with_infer_schema(
spark_version = get_spark_version(spark)
if spark_version < (3, 2):
pytest.skip("Excel files are supported on Spark 3.2+ only")
+ if spark_version >= (3, 5):
+ pytest.skip("Excel files are not supported on Spark 3.5+ yet")
file_df_connection, source_path, _ = local_fs_file_df_connection_with_path_and_files
df = file_df_dataframe
@@ -81,6 +83,8 @@ def test_excel_reader_with_options(
spark_version = get_spark_version(spark)
if spark_version < (3, 2):
pytest.skip("Excel files are supported on Spark 3.2+ only")
+ if spark_version >= (3, 5):
+ pytest.skip("Excel files are not supported on Spark 3.5+ yet")
local_fs, source_path, _ = local_fs_file_df_connection_with_path_and_files
df = file_df_dataframe
@@ -117,6 +121,8 @@ def test_excel_writer(
spark_version = get_spark_version(spark)
if spark_version < (3, 2):
pytest.skip("Excel files are supported on Spark 3.2+ only")
+ if spark_version >= (3, 5):
+ pytest.skip("Excel files are not supported on Spark 3.5+ yet")
file_df_connection, source_path = local_fs_file_df_connection_with_path
df = file_df_dataframe
diff --git a/tests/tests_integration/test_file_format_integration/test_xml_integration.py b/tests/tests_integration/test_file_format_integration/test_xml_integration.py
new file mode 100644
index 000000000..d03a6f61d
--- /dev/null
+++ b/tests/tests_integration/test_file_format_integration/test_xml_integration.py
@@ -0,0 +1,169 @@
+"""Integration tests for XML file format.
+
+Test only that options are passed to Spark in both FileDFReader & FileDFWriter.
+Do not test all the possible options and combinations, we are not testing Spark here.
+"""
+
+import pytest
+
+from onetl._util.spark import get_spark_version
+from onetl.file import FileDFReader, FileDFWriter
+from onetl.file.format import XML
+
+try:
+ from tests.util.assert_df import assert_equal_df
+except ImportError:
+ # pandas and spark can be missing if someone runs tests for file connections only
+ pass
+
+pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection]
+
+
+@pytest.fixture()
+def expected_xml_attributes_df(file_df_dataframe):
+ col_names = file_df_dataframe.columns
+ exprs = [f"{col} as _{col}" for col in col_names] + col_names
+ return file_df_dataframe.selectExpr(*exprs)
+
+
+@pytest.mark.parametrize(
+ "path, options",
+ [
+ ("without_compression", {"rowTag": "item"}),
+ ("with_compression", {"rowTag": "item", "compression": "gzip"}),
+ ("with_attributes", {"rowTag": "item", "attributePrefix": "_"}),
+ ],
+ ids=["without_compression", "with_compression", "with_attributes"],
+)
+def test_xml_reader(
+ spark,
+ local_fs_file_df_connection_with_path_and_files,
+ file_df_dataframe,
+ path,
+ options,
+):
+ """Reading XML files working as expected on any Spark, Python and Java versions"""
+ spark_version = get_spark_version(spark)
+ if spark_version < (3, 0):
+ pytest.skip("XML files are supported on Spark 3.x only")
+
+ local_fs, source_path, _ = local_fs_file_df_connection_with_path_and_files
+ df = file_df_dataframe
+ xml_root = source_path / "xml" / path
+
+ reader = FileDFReader(
+ connection=local_fs,
+ format=XML.parse(options),
+ df_schema=df.schema,
+ source_path=xml_root,
+ )
+ read_df = reader.run()
+ assert read_df.count()
+ assert read_df.schema == df.schema
+ assert_equal_df(read_df, df)
+
+
+def test_xml_reader_with_infer_schema(
+ spark,
+ local_fs_file_df_connection_with_path_and_files,
+ expected_xml_attributes_df,
+ file_df_dataframe,
+):
+ """Reading XML files with inferSchema=True working as expected on any Spark, Python and Java versions"""
+ spark_version = get_spark_version(spark)
+ if spark_version < (3, 0):
+ pytest.skip("XML files are supported on Spark 3.x only")
+
+ file_df_connection, source_path, _ = local_fs_file_df_connection_with_path_and_files
+ df = file_df_dataframe
+ xml_root = source_path / "xml" / "with_attributes"
+
+ reader = FileDFReader(
+ connection=file_df_connection,
+ format=XML(rowTag="item", inferSchema=True),
+ source_path=xml_root,
+ )
+ read_df = reader.run()
+
+ assert read_df.count()
+ assert read_df.schema != df.schema
+ assert set(read_df.columns) == set(
+ expected_xml_attributes_df.columns,
+ ) # "DataFrames have different column types: StructField('id', IntegerType(), True), StructField('id', LongType(), True), etc."
+ assert_equal_df(read_df, expected_xml_attributes_df)
+
+
+@pytest.mark.parametrize(
+ "options",
+ [
+ {"rowTag": "item", "rootTag": "root"},
+ {"rowTag": "item", "rootTag": "root", "compression": "gzip"},
+ ],
+ ids=["without_compression", "with_compression"],
+)
+def test_xml_writer(
+ spark,
+ local_fs_file_df_connection_with_path,
+ file_df_dataframe,
+ options,
+):
+ """Written files can be read by Spark"""
+ spark_version = get_spark_version(spark)
+ if spark_version < (3, 0):
+ pytest.skip("XML files are supported on Spark 3.x only")
+
+ file_df_connection, source_path = local_fs_file_df_connection_with_path
+ df = file_df_dataframe
+ xml_root = source_path / "xml"
+
+ writer = FileDFWriter(
+ connection=file_df_connection,
+ format=XML.parse(options),
+ target_path=xml_root,
+ )
+ writer.run(df)
+
+ reader = FileDFReader(
+ connection=file_df_connection,
+ format=XML.parse(options),
+ source_path=xml_root,
+ df_schema=df.schema,
+ )
+ read_df = reader.run()
+
+ assert read_df.count()
+ assert read_df.schema == df.schema
+ assert_equal_df(read_df, df)
+
+
+@pytest.mark.parametrize(
+ "options",
+ [
+ {"rowTag": "item", "attributePrefix": "_"},
+ ],
+ ids=["read_attributes"],
+)
+def test_xml_reader_with_attributes(
+ spark,
+ local_fs_file_df_connection_with_path_and_files,
+ expected_xml_attributes_df,
+ options,
+):
+ """Reading XML files with attributes works as expected"""
+ spark_version = get_spark_version(spark)
+ if spark_version < (3, 0):
+ pytest.skip("XML files are supported on Spark 3.x only")
+
+ local_fs, source_path, _ = local_fs_file_df_connection_with_path_and_files
+ xml_root = source_path / "xml" / "with_attributes"
+
+ reader = FileDFReader(
+ connection=local_fs,
+ format=XML.parse(options),
+ df_schema=expected_xml_attributes_df.schema,
+ source_path=xml_root,
+ )
+ read_df = reader.run()
+ assert read_df.count()
+ assert read_df.schema == expected_xml_attributes_df.schema
+ assert_equal_df(read_df, expected_xml_attributes_df)
diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py
index ba16f231b..598e9d1b5 100644
--- a/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py
+++ b/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py
@@ -42,7 +42,7 @@ def test_clickhouse_reader_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBReader(
connection=clickhouse,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py
index 3770bcf19..4f4820d4d 100644
--- a/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py
+++ b/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py
@@ -43,7 +43,7 @@ def test_greenplum_reader_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBReader(
connection=greenplum,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_hive_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_hive_reader_unit.py
index ecfdf7bee..acb467e3b 100644
--- a/tests/tests_unit/test_db/test_db_reader_unit/test_hive_reader_unit.py
+++ b/tests/tests_unit/test_db/test_db_reader_unit/test_hive_reader_unit.py
@@ -43,7 +43,7 @@ def test_hive_reader_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError):
DBReader(
connection=hive,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py
index c8c483447..ea9953e68 100644
--- a/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py
+++ b/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py
@@ -43,7 +43,7 @@ def test_mssql_reader_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBReader(
connection=mssql,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py
index 19c4c662f..8976393e3 100644
--- a/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py
+++ b/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py
@@ -43,7 +43,7 @@ def test_mysql_reader_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBReader(
connection=mysql,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py
index 1011f17ec..444bc596f 100644
--- a/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py
+++ b/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py
@@ -42,7 +42,7 @@ def test_oracle_reader_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBReader(
connection=oracle,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py
index c541e33cd..03c646811 100644
--- a/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py
+++ b/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py
@@ -43,7 +43,7 @@ def test_postgres_reader_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBReader(
connection=postgres,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_teradata_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_teradata_reader_unit.py
index 262897fde..c61abca72 100644
--- a/tests/tests_unit/test_db/test_db_reader_unit/test_teradata_reader_unit.py
+++ b/tests/tests_unit/test_db/test_db_reader_unit/test_teradata_reader_unit.py
@@ -43,7 +43,7 @@ def test_teradata_reader_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBReader(
connection=teradata,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py
index eb4b34028..a631c95b7 100644
--- a/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py
+++ b/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py
@@ -13,5 +13,5 @@ def test_clickhouse_writer_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBWriter(
connection=clickhouse,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py
index fb3614cdd..784d43580 100644
--- a/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py
+++ b/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py
@@ -13,5 +13,5 @@ def test_greenplum_writer_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBWriter(
connection=greenplum,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py
index 2fbfdb573..c6c6ddc70 100644
--- a/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py
+++ b/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py
@@ -13,5 +13,5 @@ def test_hive_writer_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBWriter(
connection=hive,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py
index 44618ff11..0690e6a45 100644
--- a/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py
+++ b/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py
@@ -13,5 +13,5 @@ def test_mssql_writer_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBWriter(
connection=mssql,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py
index 8eb54f397..fc00e96c3 100644
--- a/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py
+++ b/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py
@@ -13,5 +13,5 @@ def test_mysql_writer_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBWriter(
connection=mysql,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py
index 63668cacf..ae53e4515 100644
--- a/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py
+++ b/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py
@@ -13,5 +13,5 @@ def test_oracle_writer_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBWriter(
connection=oracle,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py
index d7322b5f1..a2794466f 100644
--- a/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py
+++ b/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py
@@ -13,5 +13,5 @@ def test_postgres_writer_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBWriter(
connection=postgres,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_teradata_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_teradata_writer_unit.py
index 6b9434ba7..76bc48358 100644
--- a/tests/tests_unit/test_db/test_db_writer_unit/test_teradata_writer_unit.py
+++ b/tests/tests_unit/test_db/test_db_writer_unit/test_teradata_writer_unit.py
@@ -13,5 +13,5 @@ def test_teradata_writer_wrong_table_name(spark_mock, table):
with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"):
DBWriter(
connection=teradata,
- table=table, # Required format: table="shema.table"
+ table=table, # Required format: table="schema.table"
)
diff --git a/tests/tests_unit/test_file/test_format_unit/test_avro_unit.py b/tests/tests_unit/test_file/test_format_unit/test_avro_unit.py
index dcf53e7f7..97cb42dc8 100644
--- a/tests/tests_unit/test_file/test_format_unit/test_avro_unit.py
+++ b/tests/tests_unit/test_file/test_format_unit/test_avro_unit.py
@@ -27,12 +27,12 @@ def test_avro_get_packages_scala_version_not_supported():
[
# Detect Scala version by Spark version
("2.4.0", None, "org.apache.spark:spark-avro_2.11:2.4.0"),
- ("3.4.0", None, "org.apache.spark:spark-avro_2.12:3.4.0"),
+ ("3.5.0", None, "org.apache.spark:spark-avro_2.12:3.5.0"),
# Override Scala version
("2.4.0", "2.11", "org.apache.spark:spark-avro_2.11:2.4.0"),
("2.4.0", "2.12", "org.apache.spark:spark-avro_2.12:2.4.0"),
- ("3.4.0", "2.12", "org.apache.spark:spark-avro_2.12:3.4.0"),
- ("3.4.0", "2.13", "org.apache.spark:spark-avro_2.13:3.4.0"),
+ ("3.5.0", "2.12", "org.apache.spark:spark-avro_2.12:3.5.0"),
+ ("3.5.0", "2.13", "org.apache.spark:spark-avro_2.13:3.5.0"),
],
)
def test_avro_get_packages(spark_version, scala_version, package):
diff --git a/tests/tests_unit/test_file/test_format_unit/test_xml_unit.py b/tests/tests_unit/test_file/test_format_unit/test_xml_unit.py
new file mode 100644
index 000000000..17d9dac2a
--- /dev/null
+++ b/tests/tests_unit/test_file/test_format_unit/test_xml_unit.py
@@ -0,0 +1,129 @@
+import logging
+
+import pytest
+
+from onetl.file.format import XML
+
+
+@pytest.mark.parametrize(
+ "spark_version, scala_version, package_version, expected_packages",
+ [
+ ("3.2.4", None, None, ["com.databricks:spark-xml_2.12:0.17.0"]),
+ ("3.4.1", "2.12", "0.18.0", ["com.databricks:spark-xml_2.12:0.18.0"]),
+ ("3.0.0", None, None, ["com.databricks:spark-xml_2.12:0.17.0"]),
+ ("3.0.0", "2.12", "0.17.0", ["com.databricks:spark-xml_2.12:0.17.0"]),
+ ("3.1.2", None, None, ["com.databricks:spark-xml_2.12:0.17.0"]),
+ ("3.1.2", "2.12", "0.16.0", ["com.databricks:spark-xml_2.12:0.16.0"]),
+ ("3.2.0", "2.12", None, ["com.databricks:spark-xml_2.12:0.17.0"]),
+ ("3.2.0", "2.12", "0.15.0", ["com.databricks:spark-xml_2.12:0.15.0"]),
+ ("3.2.4", "2.13", None, ["com.databricks:spark-xml_2.13:0.17.0"]),
+ ("3.4.1", "2.13", "0.18.0", ["com.databricks:spark-xml_2.13:0.18.0"]),
+ ("3.3.0", None, "0.16.0", ["com.databricks:spark-xml_2.12:0.16.0"]),
+ ("3.3.0", "2.12", None, ["com.databricks:spark-xml_2.12:0.17.0"]),
+ ],
+)
+def test_xml_get_packages(spark_version, scala_version, package_version, expected_packages):
+ result = XML.get_packages(
+ spark_version=spark_version,
+ scala_version=scala_version,
+ package_version=package_version,
+ )
+ assert result == expected_packages
+
+
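+# A minimal usage sketch (not part of the assertions): the returned Maven
+# coordinates are meant to be passed to spark.jars.packages when building a
+# session. The version literal below is illustrative only.
+#
+#   from pyspark.sql import SparkSession
+#
+#   packages = XML.get_packages(spark_version="3.5.0")
+#   spark = (
+#       SparkSession.builder
+#       .config("spark.jars.packages", ",".join(packages))
+#       .getOrCreate()
+#   )
+
+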
+@pytest.mark.parametrize(
+ "spark_version, scala_version, package_version",
+ [
+ ("2.4.8", None, None),
+ ("2.3.4", None, None),
+ ],
+)
+def test_xml_get_packages_restriction_for_spark_2x(spark_version, scala_version, package_version):
+ with pytest.raises(ValueError, match=r"Spark version must be 3.x, got \d+\.\d+"):
+ XML.get_packages(
+ spark_version=spark_version,
+ scala_version=scala_version,
+ package_version=package_version,
+ )
+
+
+@pytest.mark.parametrize(
+ "spark_version, scala_version, package_version",
+ [
+ ("3.2.4", "2.11", None),
+ ("3.4.1", "2.14", None),
+ ],
+)
+def test_xml_get_packages_scala_version_error(spark_version, scala_version, package_version):
+ with pytest.raises(ValueError, match=r"Scala version must be 2.12 or 2.13, got \d+\.\d+"):
+ XML.get_packages(
+ spark_version=spark_version,
+ scala_version=scala_version,
+ package_version=package_version,
+ )
+
+
+@pytest.mark.parametrize(
+ "spark_version, scala_version, package_version",
+ [
+ ("3.2.4", "2.12", "0.13.0"),
+ ("3.4.1", "2.12", "0.10.0"),
+ ],
+)
+def test_xml_get_packages_package_version_error(spark_version, scala_version, package_version):
+ with pytest.raises(ValueError, match=r"Package version must be above 0.13, got \d+\.\d+\.\d+"):
+ XML.get_packages(
+ spark_version=spark_version,
+ scala_version=scala_version,
+ package_version=package_version,
+ )
+
+
+@pytest.mark.parametrize(
+ "known_option",
+ [
+ "samplingRatio",
+ "excludeAttribute",
+ "treatEmptyValuesAsNulls",
+ "mode",
+ "inferSchema",
+ "columnNameOfCorruptRecord",
+ "attributePrefix",
+ "valueTag",
+ "charset",
+ "ignoreSurroundingSpaces",
+ "wildcardColName",
+ "rowValidationXSDPath",
+ "ignoreNamespace",
+ "timestampFormat",
+ "dateFormat",
+ "rootTag",
+ "declaration",
+ "arrayElementName",
+ "nullValue",
+ "compression",
+ ],
+)
+def test_xml_options_known(known_option):
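+    # Each known spark-xml option should be accepted by XML.parse and exposed
+    # as an attribute of the same name; row_tag is the only required field here.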
+ xml = XML.parse({known_option: "value", "row_tag": "item"})
+ assert getattr(xml, known_option) == "value"
+
+
+def test_xml_option_path_error():
+ msg = r"Options \['path'\] are not allowed to use in a XML"
+ with pytest.raises(ValueError, match=msg):
+ XML(row_tag="item", path="/path")
+
+
+def test_xml_options_unknown(caplog):
+ with caplog.at_level(logging.WARNING):
+ xml = XML(row_tag="item", unknownOption="abc")
+ assert xml.unknownOption == "abc"
+ assert "Options ['unknownOption'] are not known by XML, are you sure they are valid?" in caplog.text
+
+
+@pytest.mark.local_fs
+def test_xml_missing_package(spark_no_packages):
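+    # spark_no_packages is presumably a session built without spark.jars.packages,
+    # so check_if_supported cannot load the spark-xml Java class.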
+ msg = "Cannot import Java class 'com.databricks.spark.xml.XmlReader'"
+ with pytest.raises(ValueError, match=msg):
+ XML(row_tag="item").check_if_supported(spark_no_packages)
diff --git a/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py b/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py
index 4451ef104..42b5582ae 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py
@@ -33,6 +33,18 @@ def test_clickhouse_missing_package(spark_no_packages):
)
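+# spark_stopped is presumably a fixture (from conftest) yielding a SparkSession
+# whose context has already been stopped; connection constructors are expected
+# to reject it up front instead of failing later with an obscure JVM error.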
+def test_clickhouse_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ Clickhouse(
+ host="some_host",
+ user="user",
+ database="database",
+ password="passwd",
+ spark=spark_stopped,
+ )
+
+
def test_clickhouse(spark_mock):
conn = Clickhouse(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
diff --git a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py
index f54eec0a4..276a3c892 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py
@@ -31,6 +31,7 @@ def test_greenplum_get_packages_no_input():
"2.2",
"3.3",
"3.4",
+ "3.5",
],
)
def test_greenplum_get_packages_spark_version_not_supported(spark_version):
@@ -82,6 +83,18 @@ def test_greenplum_missing_package(spark_no_packages):
)
+def test_greenplum_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ Greenplum(
+ host="some_host",
+ user="user",
+ database="database",
+ password="passwd",
+ spark=spark_stopped,
+ )
+
+
def test_greenplum(spark_mock):
conn = Greenplum(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
diff --git a/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py b/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py
index 6469b10c8..01cffd4a8 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py
@@ -26,6 +26,12 @@ def test_hive_instance_url(spark_mock):
assert hive.instance_url == "some-cluster"
+def test_hive_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ Hive(cluster="some-cluster", spark=spark_stopped)
+
+
def test_hive_get_known_clusters_hook(request, spark_mock):
# no exception
Hive(cluster="unknown", spark=spark_mock)
@@ -60,8 +66,6 @@ def normalize_cluster_name(cluster: str) -> str:
def test_hive_known_get_current_cluster_hook(request, spark_mock, mocker):
- mocker.patch.object(Hive, "_execute_sql", return_value=None)
-
# no exception
Hive(cluster="rnd-prod", spark=spark_mock).check()
Hive(cluster="rnd-dwh", spark=spark_mock).check()
diff --git a/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py b/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py
index 1524d8ad7..404ca57fa 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py
@@ -70,6 +70,16 @@ def test_kafka_missing_package(spark_no_packages):
)
+def test_kafka_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ Kafka(
+ cluster="some_cluster",
+ addresses=["192.168.1.1"],
+ spark=spark_stopped,
+ )
+
+
@pytest.mark.parametrize(
"option, value",
[
diff --git a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py
index eb3f1db23..d53b4d614 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py
@@ -79,6 +79,18 @@ def test_mongodb_missing_package(spark_no_packages):
)
+def test_mongodb_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ MongoDB(
+ host="host",
+ user="user",
+ password="password",
+ database="database",
+ spark=spark_stopped,
+ )
+
+
def test_mongodb(spark_mock):
conn = MongoDB(
host="host",
diff --git a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py
index e6cd8eb89..7b0328ca9 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py
@@ -53,6 +53,18 @@ def test_mssql_missing_package(spark_no_packages):
)
+def test_mssql_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ MSSQL(
+ host="some_host",
+ user="user",
+ database="database",
+ password="passwd",
+ spark=spark_stopped,
+ )
+
+
def test_mssql(spark_mock):
conn = MSSQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
diff --git a/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py
index 2a33c1523..ed730c418 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py
@@ -33,6 +33,18 @@ def test_mysql_missing_package(spark_no_packages):
)
+def test_mysql_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ MySQL(
+ host="some_host",
+ user="user",
+ database="database",
+ password="passwd",
+ spark=spark_stopped,
+ )
+
+
def test_mysql(spark_mock):
conn = MySQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
diff --git a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py
index bddc14c0f..6a875b8f7 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py
@@ -53,6 +53,18 @@ def test_oracle_missing_package(spark_no_packages):
)
+def test_oracle_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ Oracle(
+ host="some_host",
+ user="user",
+ sid="sid",
+ password="passwd",
+ spark=spark_stopped,
+ )
+
+
def test_oracle(spark_mock):
conn = Oracle(host="some_host", user="user", sid="sid", password="passwd", spark=spark_mock)
diff --git a/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py b/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py
index 228c94753..01f85eb08 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py
@@ -33,6 +33,18 @@ def test_oracle_missing_package(spark_no_packages):
)
+def test_postgres_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ Postgres(
+ host="some_host",
+ user="user",
+ database="database",
+ password="passwd",
+ spark=spark_stopped,
+ )
+
+
def test_postgres(spark_mock):
conn = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
diff --git a/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py b/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py
index 1daf14dc4..dd9ba525d 100644
--- a/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py
+++ b/tests/tests_unit/tests_db_connection_unit/test_teradata_unit.py
@@ -33,6 +33,18 @@ def test_teradata_missing_package(spark_no_packages):
)
+def test_teradata_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ Teradata(
+ host="some_host",
+ user="user",
+ database="database",
+ password="passwd",
+ spark=spark_stopped,
+ )
+
+
def test_teradata(spark_mock):
conn = Teradata(host="some_host", user="user", database="database", password="passwd", spark=spark_mock)
diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py
index 5e85c16f1..08ca6c1f4 100644
--- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py
+++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py
@@ -5,14 +5,13 @@
import pytest
from onetl.base import BaseFileDFConnection
+from onetl.connection import SparkHDFS
from onetl.hooks import hook
pytestmark = [pytest.mark.hdfs, pytest.mark.file_df_connection, pytest.mark.connection]
-def test_spark_hdfs_connection_with_cluster(spark_mock):
- from onetl.connection import SparkHDFS
-
+def test_spark_hdfs_with_cluster(spark_mock):
hdfs = SparkHDFS(cluster="rnd-dwh", spark=spark_mock)
assert isinstance(hdfs, BaseFileDFConnection)
assert hdfs.cluster == "rnd-dwh"
@@ -21,9 +20,7 @@ def test_spark_hdfs_connection_with_cluster(spark_mock):
assert hdfs.instance_url == "rnd-dwh"
-def test_spark_hdfs_connection_with_cluster_and_host(spark_mock):
- from onetl.connection import SparkHDFS
-
+def test_spark_hdfs_with_cluster_and_host(spark_mock):
hdfs = SparkHDFS(cluster="rnd-dwh", host="some-host.domain.com", spark=spark_mock)
assert isinstance(hdfs, BaseFileDFConnection)
assert hdfs.cluster == "rnd-dwh"
@@ -31,9 +28,7 @@ def test_spark_hdfs_connection_with_cluster_and_host(spark_mock):
assert hdfs.instance_url == "rnd-dwh"
-def test_spark_hdfs_connection_with_port(spark_mock):
- from onetl.connection import SparkHDFS
-
+def test_spark_hdfs_with_port(spark_mock):
hdfs = SparkHDFS(cluster="rnd-dwh", port=9020, spark=spark_mock)
assert isinstance(hdfs, BaseFileDFConnection)
assert hdfs.cluster == "rnd-dwh"
@@ -41,9 +36,7 @@ def test_spark_hdfs_connection_with_port(spark_mock):
assert hdfs.instance_url == "rnd-dwh"
-def test_spark_hdfs_connection_without_cluster(spark_mock):
- from onetl.connection import SparkHDFS
-
+def test_spark_hdfs_without_cluster(spark_mock):
with pytest.raises(ValueError):
SparkHDFS(spark=spark_mock)
@@ -51,9 +44,13 @@ def test_spark_hdfs_connection_without_cluster(spark_mock):
SparkHDFS(host="some", spark=spark_mock)
-def test_spark_hdfs_get_known_clusters_hook(request, spark_mock):
- from onetl.connection import SparkHDFS
+def test_spark_hdfs_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ SparkHDFS(cluster="rnd-dwh", host="some-host.domain.com", spark=spark_stopped)
+
+
+def test_spark_hdfs_get_known_clusters_hook(request, spark_mock):
@SparkHDFS.Slots.get_known_clusters.bind
@hook
def get_known_clusters() -> set[str]:
@@ -71,8 +68,6 @@ def get_known_clusters() -> set[str]:
def test_spark_hdfs_known_normalize_cluster_name_hook(request, spark_mock):
- from onetl.connection import SparkHDFS
-
@SparkHDFS.Slots.normalize_cluster_name.bind
@hook
def normalize_cluster_name(cluster: str) -> str:
@@ -86,8 +81,6 @@ def normalize_cluster_name(cluster: str) -> str:
def test_spark_hdfs_get_cluster_namenodes_hook(request, spark_mock):
- from onetl.connection import SparkHDFS
-
@SparkHDFS.Slots.get_cluster_namenodes.bind
@hook
def get_cluster_namenodes(cluster: str) -> set[str]:
@@ -106,8 +99,6 @@ def get_cluster_namenodes(cluster: str) -> set[str]:
def test_spark_hdfs_normalize_namenode_host_hook(request, spark_mock):
- from onetl.connection import SparkHDFS
-
@SparkHDFS.Slots.normalize_namenode_host.bind
@hook
def normalize_namenode_host(host: str, cluster: str) -> str:
@@ -124,8 +115,6 @@ def normalize_namenode_host(host: str, cluster: str) -> str:
def test_spark_hdfs_get_ipc_port_hook(request, spark_mock):
- from onetl.connection import SparkHDFS
-
@SparkHDFS.Slots.get_ipc_port.bind
@hook
def get_ipc_port(cluster: str) -> int | None:
@@ -140,8 +129,6 @@ def get_ipc_port(cluster: str) -> int | None:
def test_spark_hdfs_known_get_current(request, spark_mock):
- from onetl.connection import SparkHDFS
-
# no hooks bound to SparkHDFS.Slots.get_current_cluster
error_msg = re.escape(
"SparkHDFS.get_current() can be used only if there are some hooks bound to SparkHDFS.Slots.get_current_cluster",
diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py
index 8c8c8f377..e98c986cf 100644
--- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py
+++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_local_fs_unit.py
@@ -23,3 +23,9 @@ def test_spark_local_fs_spark_non_local(spark_mock, master):
msg = re.escape("Currently supports only spark.master='local'")
with pytest.raises(ValueError, match=msg):
SparkLocalFS(spark=spark_mock)
+
+
+def test_spark_local_fs_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ SparkLocalFS(spark=spark_stopped)
diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py
index 4e84ef0c3..99a20633c 100644
--- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py
+++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py
@@ -10,9 +10,9 @@
@pytest.mark.parametrize(
"spark_version, scala_version, package",
[
- ("3.4.1", None, "org.apache.spark:spark-hadoop-cloud_2.12:3.4.1"),
- ("3.4.1", "2.12", "org.apache.spark:spark-hadoop-cloud_2.12:3.4.1"),
- ("3.4.1", "2.13", "org.apache.spark:spark-hadoop-cloud_2.13:3.4.1"),
+ ("3.5.0", None, "org.apache.spark:spark-hadoop-cloud_2.12:3.5.0"),
+ ("3.5.0", "2.12", "org.apache.spark:spark-hadoop-cloud_2.12:3.5.0"),
+ ("3.5.0", "2.13", "org.apache.spark:spark-hadoop-cloud_2.13:3.5.0"),
],
)
def test_spark_s3_get_packages(spark_version, scala_version, package):
@@ -32,7 +32,7 @@ def test_spark_s3_get_packages_spark_2_error(spark_version):
@pytest.mark.parametrize("hadoop_version", ["2.7.3", "2.8.0", "2.10.1"])
-def test_spark_s3_connection_with_hadoop_2_error(spark_mock, hadoop_version):
+def test_spark_s3_with_hadoop_2_error(spark_mock, hadoop_version):
spark_mock._jvm = Mock()
spark_mock._jvm.org.apache.hadoop.util.VersionInfo.getVersion = Mock(return_value=hadoop_version)
@@ -47,7 +47,7 @@ def test_spark_s3_connection_with_hadoop_2_error(spark_mock, hadoop_version):
)
-def test_spark_s3_connection_missing_package(spark_no_packages):
+def test_spark_s3_missing_package(spark_no_packages):
spark_no_packages._jvm = Mock()
spark_no_packages._jvm.org.apache.hadoop.util.VersionInfo.getVersion = Mock(return_value="3.3.6")
@@ -63,6 +63,19 @@ def test_spark_s3_connection_missing_package(spark_no_packages):
)
+def test_spark_s3_spark_stopped(spark_stopped):
+ msg = "Spark session is stopped. Please recreate Spark session."
+ with pytest.raises(ValueError, match=msg):
+ SparkS3(
+ host="some_host",
+ access_key="access_key",
+ secret_key="some key",
+ session_token="some token",
+ bucket="bucket",
+ spark=spark_stopped,
+ )
+
+
@pytest.fixture()
def spark_mock_hadoop_3(spark_mock):
spark_mock._jvm = Mock()
@@ -70,7 +83,7 @@ def spark_mock_hadoop_3(spark_mock):
return spark_mock
-def test_spark_s3_connection(spark_mock_hadoop_3):
+def test_spark_s3(spark_mock_hadoop_3):
s3 = SparkS3(
host="some_host",
access_key="access key",
@@ -91,7 +104,7 @@ def test_spark_s3_connection(spark_mock_hadoop_3):
assert "some key" not in repr(s3)
-def test_spark_s3_connection_with_protocol_https(spark_mock_hadoop_3):
+def test_spark_s3_with_protocol_https(spark_mock_hadoop_3):
s3 = SparkS3(
host="some_host",
access_key="access_key",
@@ -106,7 +119,7 @@ def test_spark_s3_connection_with_protocol_https(spark_mock_hadoop_3):
assert s3.instance_url == "s3://some_host:443"
-def test_spark_s3_connection_with_protocol_http(spark_mock_hadoop_3):
+def test_spark_s3_with_protocol_http(spark_mock_hadoop_3):
s3 = SparkS3(
host="some_host",
access_key="access_key",
@@ -122,7 +135,7 @@ def test_spark_s3_connection_with_protocol_http(spark_mock_hadoop_3):
@pytest.mark.parametrize("protocol", ["http", "https"])
-def test_spark_s3_connection_with_port(spark_mock_hadoop_3, protocol):
+def test_spark_s3_with_port(spark_mock_hadoop_3, protocol):
s3 = SparkS3(
host="some_host",
port=9000,