From 2caf92f22d046d2285b3b89ad255647879a5ef20 Mon Sep 17 00:00:00 2001 From: Maxim Martynov Date: Thu, 13 Jul 2023 17:04:19 +0300 Subject: [PATCH] [DOP-6705] Add CSV file format (#69) --- docs/file/file_filters/base.rst | 2 +- docs/file/file_formats/base.rst | 9 ++ docs/file/file_formats/csv.rst | 9 ++ docs/file/file_formats/index.rst | 16 +++ docs/file/index.rst | 1 + onetl/base/base_file_format.py | 6 +- onetl/connection/db_connection/mongodb.py | 8 +- onetl/file/format/__init__.py | 16 +++ onetl/file/format/csv.py | 103 ++++++++++++++++++ onetl/file/format/file_format.py | 53 +++++++++ onetl/impl/generic_options.py | 9 +- .../test_format_unit/test_csv_unit.py | 82 ++++++++++++++ .../test_format_unit/test_file_format_unit.py | 26 +++++ 13 files changed, 328 insertions(+), 12 deletions(-) create mode 100644 docs/file/file_formats/base.rst create mode 100644 docs/file/file_formats/csv.rst create mode 100644 docs/file/file_formats/index.rst create mode 100644 onetl/file/format/__init__.py create mode 100644 onetl/file/format/csv.py create mode 100644 onetl/file/format/file_format.py create mode 100644 tests/tests_unit/test_file/test_format_unit/test_csv_unit.py create mode 100644 tests/tests_unit/test_file/test_format_unit/test_file_format_unit.py diff --git a/docs/file/file_filters/base.rst b/docs/file/file_filters/base.rst index 2d266009e..1ba9007d6 100644 --- a/docs/file/file_filters/base.rst +++ b/docs/file/file_filters/base.rst @@ -1,4 +1,4 @@ -.. _base-filter: +.. _base-file-filter: Base interface ============== diff --git a/docs/file/file_formats/base.rst b/docs/file/file_formats/base.rst new file mode 100644 index 000000000..5b6d93d8a --- /dev/null +++ b/docs/file/file_formats/base.rst @@ -0,0 +1,9 @@ +.. _base-file-format: + +Base interface +============== + +.. currentmodule:: onetl.base.base_file_format + +.. 
autoclass:: BaseFileFormat + :members: check_if_supported, apply_to_reader, apply_to_writer diff --git a/docs/file/file_formats/csv.rst b/docs/file/file_formats/csv.rst new file mode 100644 index 000000000..44201e71a --- /dev/null +++ b/docs/file/file_formats/csv.rst @@ -0,0 +1,9 @@ +.. _csv-file-format: + +CSV +=== + +.. currentmodule:: onetl.file.format.csv + +.. autoclass:: CSV + :members: __init__ diff --git a/docs/file/file_formats/index.rst b/docs/file/file_formats/index.rst new file mode 100644 index 000000000..3611d7b79 --- /dev/null +++ b/docs/file/file_formats/index.rst @@ -0,0 +1,16 @@ +.. _file-formats: + +File Formats +============ + +.. toctree:: + :maxdepth: 1 + :caption: File formats + + csv + +.. toctree:: + :maxdepth: 1 + :caption: For developers + + base diff --git a/docs/file/index.rst b/docs/file/index.rst index df1b6a449..3b11c8210 100644 --- a/docs/file/index.rst +++ b/docs/file/index.rst @@ -9,3 +9,4 @@ file_mover/index file_filters/index file_limits/index + file_formats/index diff --git a/onetl/base/base_file_format.py b/onetl/base/base_file_format.py index f76475dce..e354aeaaa 100644 --- a/onetl/base/base_file_format.py +++ b/onetl/base/base_file_format.py @@ -28,14 +28,14 @@ class BaseFileFormat(ABC): @classmethod @abstractmethod - def check_if_available(cls, spark: SparkSession) -> None: + def check_if_supported(cls, spark: SparkSession) -> None: """ - Check if file format is available. |support_hooks| + Check if Spark session does support this file format. |support_hooks| Raises ------- RuntimeError - If file format is not available. + If file format is not supported. 
""" @abstractmethod diff --git a/onetl/connection/db_connection/mongodb.py b/onetl/connection/db_connection/mongodb.py index 172a0371c..5f355495f 100644 --- a/onetl/connection/db_connection/mongodb.py +++ b/onetl/connection/db_connection/mongodb.py @@ -795,12 +795,12 @@ def write_df_to_target( self._check_driver_imported() write_options = self.WriteOptions.parse(options) mode = write_options.mode - write_options = write_options.dict(by_alias=True, exclude_none=True, exclude={"mode"}) - write_options["connection.uri"] = self.connection_url - write_options["collection"] = target + write_options_dict = write_options.dict(by_alias=True, exclude_none=True, exclude={"mode"}) + write_options_dict["connection.uri"] = self.connection_url + write_options_dict["collection"] = target log.info("|%s| Saving data to a collection %r", self.__class__.__name__, target) - df.write.format("mongodb").mode(mode).options(**write_options).save() + df.write.format("mongodb").mode(mode).options(**write_options_dict).save() log.info("|%s| Collection %r is successfully written", self.__class__.__name__, target) @property diff --git a/onetl/file/format/__init__.py b/onetl/file/format/__init__.py new file mode 100644 index 000000000..abe8ed287 --- /dev/null +++ b/onetl/file/format/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from onetl.file.format.csv import CSV diff --git a/onetl/file/format/csv.py b/onetl/file/format/csv.py new file mode 100644 index 000000000..c9247f5d8 --- /dev/null +++ b/onetl/file/format/csv.py @@ -0,0 +1,103 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +from pydantic import Field + +from onetl.file.format.file_format import FileFormat +from onetl.hooks import slot, support_hooks + +if TYPE_CHECKING: + from pyspark.sql import SparkSession + + +READ_WRITE_OPTIONS = { + "charToEscapeQuoteEscaping", + "dateFormat", + "emptyValue", + "ignoreLeadingWhiteSpace", + "ignoreTrailingWhiteSpace", + "nullValue", + "timestampFormat", + "timestampNTZFormat", +} + +READ_OPTIONS = { + "columnNameOfCorruptRecord", + "comment", + "enableDateTimeParsingFallback", + "enforceSchema", + "inferSchema", + "locale", + "maxCharsPerColumn", + "maxColumns", + "mode", + "multiLine", + "nanValue", + "negativeInf", + "positiveInf", + "preferDate", + "samplingRatio", + "unescapedQuoteHandling", +} + +WRITE_OPTIONS = { + "compression", + "escapeQuotes", + "quoteAll", +} + + +@support_hooks +class CSV(FileFormat): + """ + CSV file format. |support_hooks| + + Based on `Spark CSV Files <https://spark.apache.org/docs/latest/sql-data-sources-csv.html>`_ file format. + + .. note :: + + You can pass any option to the constructor, even if it is not mentioned in this documentation. 
+ + Examples + -------- + + Describe options how to read from/write to CSV file with specific options: + + .. code:: python + + csv = CSV(sep=",", encoding="utf-8", inferSchema=True, compression="gzip") + + """ + + name: ClassVar[str] = "csv" + delimiter: str = Field(default=",", alias="sep") + encoding: str = "utf-8" + quote: str = '"' + escape: str = "\\" + header: bool = False + lineSep: str = "\n" # noqa: N815 + + class Config: + known_options = READ_WRITE_OPTIONS | READ_OPTIONS | WRITE_OPTIONS + extra = "allow" + + @slot + @classmethod + def check_if_supported(cls, spark: SparkSession) -> None: + # always available + pass diff --git a/onetl/file/format/file_format.py b/onetl/file/format/file_format.py new file mode 100644 index 000000000..6474420c5 --- /dev/null +++ b/onetl/file/format/file_format.py @@ -0,0 +1,53 @@ +# Copyright 2023 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar, TypeVar + +from onetl.base import BaseFileFormat +from onetl.hooks import slot, support_hooks +from onetl.impl import GenericOptions + +if TYPE_CHECKING: + from pyspark.sql import DataFrameReader, DataFrameWriter + +T = TypeVar("T") + +PROHIBITED_OPTIONS = { + "ignoreCorruptFiles", + "ignoreMissingFiles", + "modifiedAfter", + "modifiedBefore", + "pathGlobFilter", + "recursiveFileLookup", +} + + +@support_hooks +class FileFormat(BaseFileFormat, GenericOptions): + name: ClassVar[str] + + class Config: + prohibited_options = PROHIBITED_OPTIONS + + @slot + def apply_to_reader(self, reader: DataFrameReader) -> DataFrameReader: + options = self.dict(by_alias=True) + return reader.format(self.name).options(**options) + + @slot + def apply_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter: + options = self.dict(by_alias=True) + return writer.format(self.name).options(**options) diff --git a/onetl/impl/generic_options.py b/onetl/impl/generic_options.py index fca5929f4..064780534 100644 --- a/onetl/impl/generic_options.py +++ b/onetl/impl/generic_options.py @@ -16,13 +16,14 @@ import logging from fnmatch import fnmatch -from typing import Iterable +from typing import Iterable, TypeVar from pydantic import root_validator from onetl.impl.frozen_model import FrozenModel log = logging.getLogger(__name__) +T = TypeVar("T", bound="GenericOptions") class GenericOptions(FrozenModel): @@ -32,9 +33,9 @@ class Config: @classmethod def parse( - cls, + cls: type[T], options: GenericOptions | dict | None, - ): + ) -> T: """ If a parameter inherited from the ReadOptions class was passed, then it will be returned unchanged. If a Dict object was passed it will be converted to ReadOptions. 
@@ -56,7 +57,7 @@ def parse( return options @root_validator - def check_options_not_prohibited( + def check_options_allowed( cls, values, ) -> None: diff --git a/tests/tests_unit/test_file/test_format_unit/test_csv_unit.py b/tests/tests_unit/test_file/test_format_unit/test_csv_unit.py new file mode 100644 index 000000000..b68d44348 --- /dev/null +++ b/tests/tests_unit/test_file/test_format_unit/test_csv_unit.py @@ -0,0 +1,82 @@ +import logging + +import pytest + +from onetl.file.format import CSV + + +def test_csv_default_options(): + csv = CSV() + assert csv.delimiter == "," + assert csv.encoding == "utf-8" + assert csv.quote == '"' + assert csv.escape == "\\" + assert csv.header is False + assert csv.lineSep == "\n" + + +def test_csv_default_options_override(): + csv = CSV( + delimiter="value", + encoding="value", + quote="value", + escape="value", + header=True, + lineSep="value", + ) + assert csv.delimiter == "value" + assert csv.encoding == "value" + assert csv.quote == "value" + assert csv.escape == "value" + assert csv.header is True + assert csv.lineSep == "value" + + +def test_csv_delimiter_alias(): + csv = CSV(sep="value") + assert csv.delimiter == "value" + + +@pytest.mark.parametrize( + "known_option", + [ + "charToEscapeQuoteEscaping", + "dateFormat", + "emptyValue", + "ignoreLeadingWhiteSpace", + "ignoreTrailingWhiteSpace", + "nullValue", + "timestampFormat", + "timestampNTZFormat", + "columnNameOfCorruptRecord", + "comment", + "enableDateTimeParsingFallback", + "enforceSchema", + "inferSchema", + "locale", + "maxCharsPerColumn", + "maxColumns", + "mode", + "multiLine", + "nanValue", + "negativeInf", + "positiveInf", + "preferDate", + "samplingRatio", + "unescapedQuoteHandling", + "compression", + "escapeQuotes", + "quoteAll", + ], +) +def test_csv_known_options(known_option): + csv = CSV(**{known_option: "value"}) + assert getattr(csv, known_option) == "value" + + +def test_csv_unknown_options(caplog): + with caplog.at_level(logging.WARNING): + csv = 
CSV(unknown="abc") + assert csv.unknown == "abc" + + assert ("Options ['unknown'] are not known by CSV, are you sure they are valid?") in caplog.text diff --git a/tests/tests_unit/test_file/test_format_unit/test_file_format_unit.py b/tests/tests_unit/test_file/test_format_unit/test_file_format_unit.py new file mode 100644 index 000000000..750d35275 --- /dev/null +++ b/tests/tests_unit/test_file/test_format_unit/test_file_format_unit.py @@ -0,0 +1,26 @@ +import pytest + +from onetl.file.format import CSV + + +@pytest.mark.parametrize( + "prohibited_option", + [ + "ignoreCorruptFiles", + "ignoreMissingFiles", + "modifiedAfter", + "modifiedBefore", + "pathGlobFilter", + "recursiveFileLookup", + ], +) +@pytest.mark.parametrize( + "format_class", + [ + CSV, + ], +) +def test_file_format_prohibited_options(prohibited_option, format_class): + msg = rf"Options \['{prohibited_option}'\] are not allowed to use in a {format_class.__name__}" + with pytest.raises(ValueError, match=msg): + format_class(**{prohibited_option: "value"})