[DOP-9645] - Add XML file format #163

Merged: 15 commits, Oct 10, 2023

1 change: 1 addition & 0 deletions docs/changelog/next_release/163.feature.rst
@@ -0,0 +1 @@
Add ``XML`` file format support.
1 change: 1 addition & 0 deletions docs/file_df/file_formats/index.rst
@@ -14,6 +14,7 @@ File Formats
jsonline
orc
parquet
xml

.. toctree::
:maxdepth: 1
9 changes: 9 additions & 0 deletions docs/file_df/file_formats/xml.rst
@@ -0,0 +1,9 @@
.. _xml-file-format:

XML
=====

.. currentmodule:: onetl.file.format.xml

.. autoclass:: XML
:members: get_packages
1 change: 1 addition & 0 deletions onetl/file/format/__init__.py
@@ -20,3 +20,4 @@
from onetl.file.format.jsonline import JSONLine
from onetl.file.format.orc import ORC
from onetl.file.format.parquet import Parquet
from onetl.file.format.xml import XML
224 changes: 224 additions & 0 deletions onetl/file/format/xml.py
@@ -0,0 +1,224 @@
# Copyright 2023 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, ClassVar

from pydantic import Field

from onetl._util.java import try_import_java_class
from onetl._util.scala import get_default_scala_version
from onetl._util.spark import get_spark_version
from onetl._util.version import Version
from onetl.exception import MISSING_JVM_CLASS_MSG
from onetl.file.format.file_format import ReadWriteFileFormat
from onetl.hooks import slot, support_hooks

if TYPE_CHECKING:
from pyspark.sql import SparkSession


PROHIBITED_OPTIONS = frozenset(
(
# filled by onETL classes
"path",
),
)


READ_OPTIONS = frozenset(
(
"rowTag",
"samplingRatio",
"excludeAttribute",
"treatEmptyValuesAsNulls",
"mode",
"inferSchema",
"columnNameOfCorruptRecord",
"attributePrefix",
"valueTag",
"charset",
"ignoreSurroundingSpaces",
"wildcardColName",
"rowValidationXSDPath",
"ignoreNamespace",
"timestampFormat",
"dateFormat",
),
)

WRITE_OPTIONS = frozenset(
(
"rowTag",
"rootTag",
"declaration",
"arrayElementName",
"nullValue",
"attributePrefix",
"valueTag",
"compression",
"timestampFormat",
"dateFormat",
),
)


log = logging.getLogger(__name__)


@support_hooks
class XML(ReadWriteFileFormat):
"""
XML file format. |support_hooks|

Based on `Databricks Spark XML <https://github.com/databricks/spark-xml>`_ file format.

Supports reading/writing files with the ``.xml`` extension.

.. versionadded:: 0.9.5

.. dropdown:: Version compatibility

* Spark versions: 3.2.x - 3.4.x.

* Scala versions: 2.12 - 2.13

* Java versions: 8 - 20


See the documentation at the link above.

.. note ::

You can pass any option to the constructor, even if it is not mentioned in this documentation.
**Option names should be in** ``camelCase``!

The set of supported options depends on the Spark version. See the link above.

Examples
--------
Define options for reading from or writing to XML files:

.. code:: python

from onetl.file.format import XML
from pyspark.sql import SparkSession

# Create Spark session with XML package loaded
maven_packages = XML.get_packages(spark_version="3.4.1")
spark = (
SparkSession.builder.appName("spark-app-name")
.config("spark.jars.packages", ",".join(maven_packages))
.getOrCreate()
)

xml = XML(row_tag="item")

"""

name: ClassVar[str] = "xml"

row_tag: str = Field(alias="rowTag")

class Config:
known_options = READ_OPTIONS | WRITE_OPTIONS
prohibited_options = PROHIBITED_OPTIONS
extra = "allow"

@slot
@classmethod
def get_packages( # noqa: WPS231
cls,
spark_version: str,
scala_version: str | None = None,
package_version: str | None = None,
) -> list[str]:
"""
Get package names to be downloaded by Spark. |support_hooks|

Parameters
----------
spark_version : str
Spark version in format ``major.minor.patch``.

scala_version : str, optional
Scala version in format ``major.minor``.

If ``None``, ``spark_version`` is used to determine Scala version.

package_version : str, optional
Package version in format ``major.minor.patch``. Default is ``0.17.0``.

.. warning::

Versions ``0.13`` and below are not supported.

.. note::

It is not guaranteed that custom package versions are supported.
Tests are performed only for the default version.

Examples
--------

.. code:: python

from onetl.file.format import XML

XML.get_packages(spark_version="3.4.1")
XML.get_packages(spark_version="3.4.1", scala_version="2.12")
XML.get_packages(
spark_version="3.4.1",
scala_version="2.12",
package_version="0.17.0",
)

"""

if package_version:
version = Version.parse(package_version)
log.warning("Passed custom package version %r, it is not guaranteed to be supported", package_version)
else:
version = Version.parse("0.17.0")

spark_ver = Version.parse(spark_version)
scala_ver = Version.parse(scala_version) if scala_version else get_default_scala_version(spark_ver)

# Ensure compatibility with Spark and Scala versions
if spark_ver < (3, 0):
raise ValueError(f"Spark version must be 3.x, got {spark_ver}")

if scala_ver < (2, 12) or scala_ver > (2, 13):
raise ValueError(f"Scala version must be 2.12 or 2.13, got {scala_ver}")

return [f"com.databricks:spark-xml_{scala_ver.digits(2)}:{version.digits(3)}"]

@slot
def check_if_supported(self, spark: SparkSession) -> None:
java_class = "com.databricks.spark.xml.XmlReader"

try:
try_import_java_class(spark, java_class)
except Exception as e:
spark_version = get_spark_version(spark)
msg = MISSING_JVM_CLASS_MSG.format(
java_class=java_class,
package_source=self.__class__.__name__,
args=f"spark_version='{spark_version}'",
)
if log.isEnabledFor(logging.DEBUG):
log.debug("Missing Java class", exc_info=e, stack_info=True)

raise ValueError(msg) from e
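
For context, a rough sketch of how the new format is expected to plug into onetl's file reading API. This is not part of the diff: FileDFReader, SparkLocalFS, and the source path are assumed from the existing onetl API and are illustrative only.

from pyspark.sql import SparkSession

from onetl.connection import SparkLocalFS
from onetl.file import FileDFReader
from onetl.file.format import XML

# The Spark session must be created with the spark-xml package returned by XML.get_packages
maven_packages = XML.get_packages(spark_version="3.4.1")
spark = (
    SparkSession.builder.appName("spark-app-name")
    .config("spark.jars.packages", ",".join(maven_packages))
    .getOrCreate()
)

# SparkLocalFS and the source path below are assumptions used for illustration
local_fs = SparkLocalFS(spark=spark)

reader = FileDFReader(
    connection=local_fs,
    format=XML(row_tag="item"),
    source_path="/some/path/with/xml/files",  # hypothetical path
)
df = reader.run()
df.show()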
63 changes: 56 additions & 7 deletions tests/fixtures/connections/file_df_connections.py
@@ -56,13 +56,62 @@ def file_df_schema_str_value_last():
@pytest.fixture()
def file_df_dataframe(spark, file_df_schema):
data = [
[1, "val1", 123, datetime.date(2021, 1, 1), datetime.datetime(2021, 1, 1, 1, 1, 1), 1.23],
[2, "val1", 234, datetime.date(2022, 2, 2), datetime.datetime(2022, 2, 2, 2, 2, 2), 2.34],
[3, "val2", 345, datetime.date(2023, 3, 3), datetime.datetime(2023, 3, 3, 3, 3, 3), 3.45],
[4, "val2", 456, datetime.date(2024, 4, 4), datetime.datetime(2024, 4, 4, 4, 4, 4), 4.56],
[5, "val3", 567, datetime.date(2025, 5, 5), datetime.datetime(2025, 5, 5, 5, 5, 5), 5.67],
[6, "val3", 678, datetime.date(2026, 6, 6), datetime.datetime(2026, 6, 6, 6, 6, 6), 6.78],
[7, "val3", 789, datetime.date(2027, 7, 7), datetime.datetime(2027, 7, 7, 7, 7, 7), 7.89],
[
1,
"val1",
123,
datetime.date(2021, 1, 1),
datetime.datetime(2021, 1, 1, 1, 1, 1, tzinfo=datetime.timezone.utc),
1.23,
],
[
2,
"val1",
234,
datetime.date(2022, 2, 2),
datetime.datetime(2022, 2, 2, 2, 2, 2, tzinfo=datetime.timezone.utc),
2.34,
],
[
3,
"val2",
345,
datetime.date(2023, 3, 3),
datetime.datetime(2023, 3, 3, 3, 3, 3, tzinfo=datetime.timezone.utc),
3.45,
],
[
4,
"val2",
456,
datetime.date(2024, 4, 4),
datetime.datetime(2024, 4, 4, 4, 4, 4, tzinfo=datetime.timezone.utc),
4.56,
],
[
5,
"val3",
567,
datetime.date(2025, 5, 5),
datetime.datetime(2025, 5, 5, 5, 5, 5, tzinfo=datetime.timezone.utc),
5.67,
],
[
6,
"val3",
678,
datetime.date(2026, 6, 6),
datetime.datetime(2026, 6, 6, 6, 6, 6, tzinfo=datetime.timezone.utc),
6.78,
],
[
7,
"val3",
789,
datetime.date(2027, 7, 7),
datetime.datetime(2027, 7, 7, 7, 7, 7, tzinfo=datetime.timezone.utc),
7.89,
],
]
return spark.createDataFrame(data, schema=file_df_schema)

5 changes: 4 additions & 1 deletion tests/fixtures/spark.py
@@ -44,7 +44,7 @@ def maven_packages():
SparkS3,
Teradata,
)
from onetl.file.format import Avro, Excel
from onetl.file.format import XML, Avro, Excel

pyspark_version = get_pyspark_version()
packages = (
@@ -71,6 +71,9 @@
# There is no SparkS3 connector for Spark less than 3
packages.extend(SparkS3.get_packages(spark_version=pyspark_version))

# There is no XML file support for Spark less than 3
packages.extend(XML.get_packages(pyspark_version))

# There is no MongoDB connector for Spark less than 3.2
packages.extend(MongoDB.get_packages(spark_version=pyspark_version))

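With the spark-xml package now loaded by the fixture above, the test dataframe can be round-tripped through the underlying Databricks data source. A minimal sketch, assuming the file_df_dataframe and spark fixtures are available; the output path and the root/row tags are illustrative, not part of this diff.

# Illustrative only: write the fixture dataframe with the Databricks spark-xml
# data source and read it back; the tags and output path are assumptions.
path = "/tmp/xml_roundtrip"  # hypothetical location

file_df_dataframe.write.format("xml").option("rootTag", "items").option("rowTag", "item").save(path)

restored = spark.read.format("xml").option("rowTag", "item").load(path)
assert restored.count() == file_df_dataframe.count()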