Skip to content

Commit

Permalink
[DOP-6705] Add CSV file format (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus authored Jul 13, 2023
1 parent 8e060bc commit 2caf92f
Show file tree
Hide file tree
Showing 13 changed files with 328 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/file/file_filters/base.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.. _base-filter:
.. _base-file-filter:

Base interface
==============
Expand Down
9 changes: 9 additions & 0 deletions docs/file/file_formats/base.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.. _base-file-format:

Base interface
==============

.. currentmodule:: onetl.base.base_file_format

.. autoclass:: BaseFileFormat
:members: check_if_supported, apply_to_reader, apply_to_writer
9 changes: 9 additions & 0 deletions docs/file/file_formats/csv.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.. _csv-file-format:

CSV
===

.. currentmodule:: onetl.file.format.csv

.. autoclass:: CSV
:members: __init__
16 changes: 16 additions & 0 deletions docs/file/file_formats/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
.. _file-formats:

File Formats
============

.. toctree::
:maxdepth: 1
:caption: File formats

csv

.. toctree::
:maxdepth: 1
:caption: For developers

base
1 change: 1 addition & 0 deletions docs/file/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
file_mover/index
file_filters/index
file_limits/index
file_formats/index
6 changes: 3 additions & 3 deletions onetl/base/base_file_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ class BaseFileFormat(ABC):

@classmethod
@abstractmethod
def check_if_available(cls, spark: SparkSession) -> None:
def check_if_supported(cls, spark: SparkSession) -> None:
"""
Check if file format is available. |support_hooks|
Check if Spark session does support this file format. |support_hooks|
Raises
-------
RuntimeError
If file format is not available.
If file format is not supported.
"""

@abstractmethod
Expand Down
8 changes: 4 additions & 4 deletions onetl/connection/db_connection/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,12 +795,12 @@ def write_df_to_target(
self._check_driver_imported()
write_options = self.WriteOptions.parse(options)
mode = write_options.mode
write_options = write_options.dict(by_alias=True, exclude_none=True, exclude={"mode"})
write_options["connection.uri"] = self.connection_url
write_options["collection"] = target
write_options_dict = write_options.dict(by_alias=True, exclude_none=True, exclude={"mode"})
write_options_dict["connection.uri"] = self.connection_url
write_options_dict["collection"] = target

log.info("|%s| Saving data to a collection %r", self.__class__.__name__, target)
df.write.format("mongodb").mode(mode).options(**write_options).save()
df.write.format("mongodb").mode(mode).options(**write_options_dict).save()
log.info("|%s| Collection %r is successfully written", self.__class__.__name__, target)

@property
Expand Down
16 changes: 16 additions & 0 deletions onetl/file/format/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2023 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from onetl.file.format.csv import CSV
103 changes: 103 additions & 0 deletions onetl/file/format/csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2023 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

from pydantic import Field

from onetl.file.format.file_format import FileFormat
from onetl.hooks import slot, support_hooks

if TYPE_CHECKING:
from pyspark.sql import SparkSession


# Options accepted by Spark's CSV data source for BOTH reading and writing.
# Passing one of these to CSV(...) is considered "known" by GenericOptions
# and produces no warning.
READ_WRITE_OPTIONS = {
    "charToEscapeQuoteEscaping",
    "dateFormat",
    "emptyValue",
    "ignoreLeadingWhiteSpace",
    "ignoreTrailingWhiteSpace",
    "nullValue",
    "timestampFormat",
    "timestampNTZFormat",
}

# Options accepted only when READING CSV files.
# NOTE(review): list presumably mirrors the Spark CSV data source documentation
# for the targeted Spark versions — confirm when bumping Spark.
READ_OPTIONS = {
    "columnNameOfCorruptRecord",
    "comment",
    "enableDateTimeParsingFallback",
    "enforceSchema",
    "inferSchema",
    "locale",
    "maxCharsPerColumn",
    "maxColumns",
    "mode",
    "multiLine",
    "nanValue",
    "negativeInf",
    "positiveInf",
    "preferDate",
    "samplingRatio",
    "unescapedQuoteHandling",
}

# Options accepted only when WRITING CSV files.
WRITE_OPTIONS = {
    "compression",
    "escapeQuotes",
    "quoteAll",
}


@support_hooks
class CSV(FileFormat):
    """
    CSV file format. |support_hooks|

    Based on `Spark CSV Files <https://spark.apache.org/docs/latest/sql-data-sources-csv.html>`_ file format.

    .. note ::

        You can pass any option to the constructor, even if it is not mentioned in this documentation.

    Examples
    --------

    Describe options how to read from/write to CSV file with specific options:

    .. code:: python

        csv = CSV(sep=",", encoding="utf-8", inferSchema=True, compression="gzip")
    """

    # Spark data source name used by FileFormat.apply_to_reader/apply_to_writer
    name: ClassVar[str] = "csv"

    # Serialized by alias, so Spark receives this as the "sep" option
    delimiter: str = Field(default=",", alias="sep")
    encoding: str = "utf-8"
    quote: str = '"'
    escape: str = "\\"
    header: bool = False
    lineSep: str = "\n"  # noqa: N815

    class Config:
        # Options recognized by Spark's CSV source; anything outside this set
        # is still accepted (extra = "allow") but triggers a warning log.
        known_options = READ_WRITE_OPTIONS | READ_OPTIONS | WRITE_OPTIONS
        extra = "allow"

    @slot
    @classmethod
    def check_if_supported(cls, spark: SparkSession) -> None:
        # CSV support is built into Spark itself, so every session supports it —
        # no package/version check is needed, hence no RuntimeError is ever raised.
        pass
53 changes: 53 additions & 0 deletions onetl/file/format/file_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2023 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, TypeVar

from onetl.base import BaseFileFormat
from onetl.hooks import slot, support_hooks
from onetl.impl import GenericOptions

if TYPE_CHECKING:
from pyspark.sql import DataFrameReader, DataFrameWriter

T = TypeVar("T")

# Generic Spark file-source options rejected by GenericOptions validation
# for every FileFormat subclass.
# NOTE(review): presumably these are prohibited because file filtering/listing
# is handled by onetl itself rather than delegated to Spark — confirm.
PROHIBITED_OPTIONS = {
    "ignoreCorruptFiles",
    "ignoreMissingFiles",
    "modifiedAfter",
    "modifiedBefore",
    "pathGlobFilter",
    "recursiveFileLookup",
}


@support_hooks
class FileFormat(BaseFileFormat, GenericOptions):
    """Base class for Spark-backed file formats (e.g. CSV).

    Subclasses set :attr:`name` to the Spark data source name and declare their
    options as pydantic fields; all options are forwarded to Spark by
    :meth:`apply_to_reader` / :meth:`apply_to_writer`.
    """

    # Spark data source name, e.g. "csv"; must be defined by each subclass
    name: ClassVar[str]

    class Config:
        # Generic Spark file-source options which are not allowed on formats
        prohibited_options = PROHIBITED_OPTIONS

    @slot
    def apply_to_reader(self, reader: DataFrameReader) -> DataFrameReader:
        """Set this format and all its options on a DataFrameReader. |support_hooks|"""
        # by_alias=True so fields like CSV.delimiter are passed as Spark's "sep"
        options = self.dict(by_alias=True)
        return reader.format(self.name).options(**options)

    @slot
    def apply_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter:
        """Set this format and all its options on a DataFrameWriter. |support_hooks|"""
        options = self.dict(by_alias=True)
        return writer.format(self.name).options(**options)
9 changes: 5 additions & 4 deletions onetl/impl/generic_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

import logging
from fnmatch import fnmatch
from typing import Iterable
from typing import Iterable, TypeVar

from pydantic import root_validator

from onetl.impl.frozen_model import FrozenModel

log = logging.getLogger(__name__)
T = TypeVar("T", bound="GenericOptions")


class GenericOptions(FrozenModel):
Expand All @@ -32,9 +33,9 @@ class Config:

@classmethod
def parse(
cls,
cls: type[T],
options: GenericOptions | dict | None,
):
) -> T:
"""
If a parameter inherited from the ReadOptions class was passed, then it will be returned unchanged.
If a Dict object was passed it will be converted to ReadOptions.
Expand All @@ -56,7 +57,7 @@ def parse(
return options

@root_validator
def check_options_not_prohibited(
def check_options_allowed(
cls,
values,
) -> None:
Expand Down
82 changes: 82 additions & 0 deletions tests/tests_unit/test_file/test_format_unit/test_csv_unit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import logging

import pytest

from onetl.file.format import CSV


def test_csv_default_options():
    # A freshly constructed CSV exposes the documented defaults.
    expected_defaults = {
        "delimiter": ",",
        "encoding": "utf-8",
        "quote": '"',
        "escape": "\\",
        "header": False,
        "lineSep": "\n",
    }
    csv = CSV()
    for attribute, expected in expected_defaults.items():
        assert getattr(csv, attribute) == expected


def test_csv_default_options_override():
    # Every default can be overridden via the constructor.
    overrides = {
        "delimiter": "value",
        "encoding": "value",
        "quote": "value",
        "escape": "value",
        "header": True,
        "lineSep": "value",
    }
    csv = CSV(**overrides)
    for attribute, expected in overrides.items():
        assert getattr(csv, attribute) == expected


def test_csv_delimiter_alias():
    # "sep" is the Spark option name and works as an alias for "delimiter".
    assert CSV(sep="value").delimiter == "value"


@pytest.mark.parametrize(
    "known_option",
    [
        # read & write options
        "charToEscapeQuoteEscaping",
        "dateFormat",
        "emptyValue",
        "ignoreLeadingWhiteSpace",
        "ignoreTrailingWhiteSpace",
        "nullValue",
        "timestampFormat",
        "timestampNTZFormat",
        # read-only options
        "columnNameOfCorruptRecord",
        "comment",
        "enableDateTimeParsingFallback",
        "enforceSchema",
        "inferSchema",
        "locale",
        "maxCharsPerColumn",
        "maxColumns",
        "mode",
        "multiLine",
        "nanValue",
        "negativeInf",
        "positiveInf",
        "preferDate",
        "samplingRatio",
        "unescapedQuoteHandling",
        # write-only options
        "compression",
        "escapeQuotes",
        "quoteAll",
    ],
)
def test_csv_known_options(known_option):
    # Known options are accepted as-is and stored under the same attribute name
    # without any warning being produced.
    csv = CSV(**{known_option: "value"})
    assert getattr(csv, known_option) == "value"


def test_csv_unknown_options(caplog):
    # Unknown options are still stored on the model (extra = "allow"),
    # but a warning is emitted to help catch typos.
    with caplog.at_level(logging.WARNING):
        csv = CSV(unknown="abc")
        assert csv.unknown == "abc"

    expected_warning = "Options ['unknown'] are not known by CSV, are you sure they are valid?"
    assert expected_warning in caplog.text
Loading

0 comments on commit 2caf92f

Please sign in to comment.