diff --git a/docs/changelog/next_release/269.feature.rst b/docs/changelog/next_release/269.feature.rst
new file mode 100644
index 000000000..53bb70363
--- /dev/null
+++ b/docs/changelog/next_release/269.feature.rst
@@ -0,0 +1 @@
+Add ``XML.parse_column`` method for handling XML data within Spark. This method parses XML strings directly into structured Spark DataFrame columns.
diff --git a/docs/connection/db_connection/kafka/format_handling.rst b/docs/connection/db_connection/kafka/format_handling.rst
index bc1993a20..5f2d00864 100644
--- a/docs/connection/db_connection/kafka/format_handling.rst
+++ b/docs/connection/db_connection/kafka/format_handling.rst
@@ -242,3 +242,56 @@ To serialize structured data into Avro format and write it back to a Kafka topic
     # | 1|[02 02 02 08 76 6... (binary data)] |
     # | 2|[02 04 02 08 76 6... (binary data)] |
     # +---+------------------------------------+
+
+XML Format Handling
+-------------------
+
+Handling XML data from Kafka involves parsing string representations of XML into structured Spark DataFrame columns.
+
+``DBReader``
+~~~~~~~~~~~~
+
+To process XML formatted data from Kafka, use the :obj:`XML.parse_column <onetl.file.format.xml.XML.parse_column>` method. This method converts a column containing XML strings directly into a structured Spark column using a specified schema.
+
+.. code-block:: python
+
+    from pyspark.sql import SparkSession
+    from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+
+    from onetl.db import DBReader
+    from onetl.file.format import XML
+    from onetl.connection import Kafka
+
+    spark = SparkSession.builder.appName("KafkaXMLExample").getOrCreate()
+
+    kafka = Kafka(...)
+    xml = XML(row_tag="person")
+
+    reader = DBReader(
+        connection=kafka,
+        topic="topic_name",
+    )
+    df = reader.run()
+
+    df.show()
+    # +----+--------------------------------------------------+--------+---------+------+-----------------------+-------------+
+    # |key |value                                             |topic   |partition|offset|timestamp              |timestampType|
+    # +----+--------------------------------------------------+--------+---------+------+-----------------------+-------------+
+    # |[31]|"<person><name>Alice</name><age>20</age></person>"|topicXML|0        |0     |2024-04-24 13:02:25.911|0            |
+    # |[32]|"<person><name>Bob</name><age>25</age></person>"  |topicXML|0        |1     |2024-04-24 13:02:25.922|0            |
+    # +----+--------------------------------------------------+--------+---------+------+-----------------------+-------------+
+
+    xml_schema = StructType(
+        [
+            StructField("name", StringType(), nullable=True),
+            StructField("age", IntegerType(), nullable=True),
+        ]
+    )
+    parsed_xml_df = df.select(xml.parse_column("value", xml_schema))
+    parsed_xml_df.show()
+    # +-----------+
+    # |value      |
+    # +-----------+
+    # |{Alice, 20}|
+    # |{Bob, 25}  |
+    # +-----------+
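The parsed ``value`` column in the ``DBReader`` example above is a single struct per row. A minimal follow-up sketch, assuming the same ``df``, ``xml`` and ``xml_schema`` objects as in that example, of expanding the struct into plain top-level columns:

.. code-block:: python

    # parse_column aliases its result to the source column name ("value"),
    # so the struct fields can be expanded with a "value.*" selection
    flat_df = df.select(xml.parse_column("value", xml_schema)).select("value.*")
    flat_df.show()
    # +-----+---+
    # |name |age|
    # +-----+---+
    # |Alice|20 |
    # |Bob  |25 |
    # +-----+---+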
diff --git a/docs/file_df/file_formats/xml.rst b/docs/file_df/file_formats/xml.rst
index 187aa89a4..cfe560ccc 100644
--- a/docs/file_df/file_formats/xml.rst
+++ b/docs/file_df/file_formats/xml.rst
@@ -6,4 +6,4 @@ XML
 .. currentmodule:: onetl.file.format.xml
 
 .. autoclass:: XML
-    :members: get_packages
+    :members: get_packages, parse_column
diff --git a/onetl/file/format/xml.py b/onetl/file/format/xml.py
index 83c02329b..f1dc337b3 100644
--- a/onetl/file/format/xml.py
+++ b/onetl/file/format/xml.py
@@ -19,7 +19,8 @@
 from onetl.hooks import slot, support_hooks
 
 if TYPE_CHECKING:
-    from pyspark.sql import SparkSession
+    from pyspark.sql import Column, SparkSession
+    from pyspark.sql.types import StructType
 
 
 PROHIBITED_OPTIONS = frozenset(
@@ -226,3 +227,113 @@ def check_if_supported(self, spark: SparkSession) -> None:
             if log.isEnabledFor(logging.DEBUG):
                 log.debug("Missing Java class", exc_info=e, stack_info=True)
             raise ValueError(msg) from e
+
+    def parse_column(self, column: str | Column, schema: StructType) -> Column:
+        """
+        Parses an XML string column into a structured Spark SQL column, based on the provided schema,
+        using the ``from_xml`` function provided by the
+        `Databricks Spark XML library <https://github.com/databricks/spark-xml>`_.
+
+        .. note::
+
+            This method assumes that the ``spark-xml`` package is installed: :obj:`XML.get_packages <onetl.file.format.xml.XML.get_packages>`.
+
+        .. note::
+
+            This method parses each DataFrame row individually. Therefore, for a given column, each row must contain exactly one occurrence of the specified ``rowTag``. If your XML data includes a root tag that encapsulates multiple row tags, adjust the schema to use an ``ArrayType`` to keep all child elements under the single root:
+
+            .. code-block:: xml
+
+                <books>
+                    <book><title>Book One</title><author>Author A</author></book>
+                    <book><title>Book Two</title><author>Author B</author></book>
+                </books>
+
+            And the corresponding schema in Spark using an ``ArrayType``:
+
+            .. code-block:: python
+
+                from pyspark.sql.types import StructType, StructField, ArrayType, StringType
+
+                schema = StructType(
+                    [
+                        StructField(
+                            "book",
+                            ArrayType(
+                                StructType(
+                                    [
+                                        StructField("title", StringType(), True),
+                                        StructField("author", StringType(), True),
+                                    ]
+                                )
+                            ),
+                            True,
+                        )
+                    ]
+                )
+
+        Parameters
+        ----------
+        column : str | Column
+            The name of the column or the Column object containing XML strings to parse.
+
+        schema : StructType
+            The expected schema of the parsed XML structure.
+
+        Returns
+        -------
+        Column
+            A new Column object with data parsed from XML string to the specified structured format.
+
+        Examples
+        --------
+
+        .. code-block:: python
+
+            from pyspark.sql import SparkSession
+            from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+
+            from onetl.file.format import XML
+
+            spark = SparkSession.builder.appName("XMLParsingExample").getOrCreate()
+            schema = StructType(
+                [
+                    StructField("author", StringType(), nullable=True),
+                    StructField("title", StringType(), nullable=True),
+                    StructField("genre", StringType(), nullable=True),
+                    StructField("price", IntegerType(), nullable=True),
+                ]
+            )
+            xml_processor = XML(row_tag="book")
+
+            data = [
+                (
+                    "<book><author>Austen, Jane</author><title>Pride and Prejudice</title><genre>romance</genre><price>19</price></book>",
+                )
+            ]
+            df = spark.createDataFrame(data, ["xml_string"])
+
+            parsed_df = df.select(xml_processor.parse_column("xml_string", schema=schema))
+            parsed_df.show()
+
+        """
+        from pyspark.sql import Column, SparkSession  # noqa: WPS442
+
+        spark = SparkSession._instantiatedSession  # noqa: WPS437
+        self.check_if_supported(spark)
+
+        from pyspark.sql.column import _to_java_column  # noqa: WPS450
+        from pyspark.sql.functions import col
+
+        if isinstance(column, Column):
+            column_name, column = column._jc.toString(), column.cast("string")  # noqa: WPS437
+        else:
+            column_name, column = column, col(column).cast("string")
+
+        java_column = _to_java_column(column)
+        java_schema = spark._jsparkSession.parseDataType(schema.json())  # noqa: WPS437
+        scala_options = spark._jvm.org.apache.spark.api.python.PythonUtils.toScalaMap(  # noqa: WPS219, WPS437
+            self.dict(),
+        )
+        jc = spark._jvm.com.databricks.spark.xml.functions.from_xml(  # noqa: WPS219, WPS437
+            java_column,
+            java_schema,
+            scala_options,
+        )
+        return Column(jc).alias(column_name)
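For the wrapper-root case described in the ``parse_column`` note above, a minimal sketch of turning the parsed ``ArrayType`` column into one row per ``<book>`` element. The DataFrame ``df``, its ``xml_payload`` string column, the ``xml`` instance, and the ``schema`` from the note are assumptions of this sketch:

.. code-block:: python

    from pyspark.sql.functions import explode

    # parse_column keeps the original column name, so the parsed struct is addressable as "xml_payload"
    parsed = df.select(xml.parse_column("xml_payload", schema=schema))

    # explode the array of <book> structs into one row per book, then expand the struct fields
    books_df = parsed.select(explode("xml_payload.book").alias("book")).select("book.*")
    books_df.show()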
""" +import datetime + import pytest from onetl._util.spark import get_spark_version @@ -11,9 +13,12 @@ from onetl.file.format import XML try: + from pyspark.sql import Row + from pyspark.sql.functions import col + from tests.util.assert_df import assert_equal_df except ImportError: - pytest.skip("Missing pandas", allow_module_level=True) + pytest.skip("Missing pandas or pyspark", allow_module_level=True) pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection, pytest.mark.xml] @@ -166,3 +171,45 @@ def test_xml_reader_with_attributes( assert read_df.count() assert read_df.schema == expected_xml_attributes_df.schema assert_equal_df(read_df, expected_xml_attributes_df, order_by="id") + + +@pytest.mark.parametrize( + "xml_input, expected_row", + [ + ( + """ + 1 + Alice + 123 + 2021-01-01 + 2021-01-01T07:01:01Z + 1.23 + """, + Row( + xml_string=Row( + id=1, + str_value="Alice", + int_value=123, + date_value=datetime.date(2021, 1, 1), + datetime_value=datetime.datetime(2021, 1, 1, 7, 1, 1), + float_value=1.23, + ), + ), + ), + ], + ids=["basic-case"], +) +@pytest.mark.parametrize("column_type", [str, col]) +def test_xml_parse_column(spark, xml_input: str, expected_row: Row, column_type, file_df_schema): + from onetl.file.format import XML + + spark_version = get_spark_version(spark) + if spark_version.major < 3: + pytest.skip("XML files are supported on Spark 3.x only") + + xml = XML(row_tag="item") + df = spark.createDataFrame([(xml_input,)], ["xml_string"]) + parsed_df = df.select(xml.parse_column(column_type("xml_string"), schema=file_df_schema)) + result_row = parsed_df.first() + + assert result_row == expected_row