Restore Python SDK 1.4 pandas load option classes with deprecation warning #1795

Merged
9 commits merged on Feb 23, 2023
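
In short, this PR restores the Python SDK 1.4 per-format pandas load option classes (`PandasCsvLoadOptions`, `PandasJsonLoadOptions`, `PandasNdjsonLoadOptions`, `PandasParquetLoadOptions`) as thin subclasses of `PandasLoadOptions` that only emit a `DeprecationWarning`, so 1.4-style DAGs keep working while users migrate. A minimal sketch of the migration path (not part of the diff; the parameter value is illustrative):

```python
from astro.dataframes.load_options import PandasCsvLoadOptions, PandasLoadOptions

# 1.4-style class: still works after this PR, but instantiation emits a DeprecationWarning.
legacy_options = PandasCsvLoadOptions(delimiter="$")

# Recommended replacement going forward: the generic PandasLoadOptions.
options = PandasLoadOptions(delimiter="$")
```
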
2 changes: 1 addition & 1 deletion python-sdk/src/astro/databases/databricks/delta.py
@@ -25,7 +25,7 @@


class DeltaDatabase(BaseDatabase):
LOAD_OPTIONS_CLASS_NAME = "DeltaLoadOptions"
LOAD_OPTIONS_CLASS_NAME = ("DeltaLoadOptions",)
_create_table_statement: str = "CREATE TABLE IF NOT EXISTS {} USING DELTA AS {} "

def __init__(self, conn_id: str, table: BaseTable | None = None, load_options: LoadOptions | None = None):
2 changes: 1 addition & 1 deletion python-sdk/src/astro/databases/snowflake.py
@@ -242,7 +242,7 @@ class SnowflakeDatabase(BaseDatabase):
logic in other parts of our code-base.
"""

LOAD_OPTIONS_CLASS_NAME = "SnowflakeLoadOptions"
LOAD_OPTIONS_CLASS_NAME = ("SnowflakeLoadOptions",)

NATIVE_LOAD_EXCEPTIONS: Any = (
DatabaseCustomError,
65 changes: 65 additions & 0 deletions python-sdk/src/astro/dataframes/load_options.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import warnings

import attr
from pandas._typing import DtypeArg

@@ -42,3 +44,66 @@ def populate_kwargs(self, kwargs):
for key in exclude_key:
kwargs.update(self.to_dict()[key])
return kwargs


class PandasCsvLoadOptions(PandasLoadOptions):
"""
Pandas load options while reading and loading a CSV file.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated, and will be removed in astro-sdk-python>=2.0.0.
Please use `astro.dataframes.load_options.PandasLoadOptions`.""",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)


class PandasJsonLoadOptions(PandasLoadOptions):
"""
Pandas load options while reading and loading a JSON file.

:param encoding: Encoding to use for UTF when reading/writing (ex. ‘utf-8’).
List of Python standard encodings: https://docs.python.org/3/library/codecs.html#standard-encodings
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated, and will be removed in astro-sdk-python>=2.0.0.
Please use `astro.dataframes.load_options.PandasLoadOptions`.""",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)


class PandasNdjsonLoadOptions(PandasLoadOptions):
"""
Pandas load options while reading and loading an NDJSON file.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated, and will be removed in astro-sdk-python>=2.0.0.
Please use `astro.dataframes.load_options.PandasLoadOptions`.""",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)


class PandasParquetLoadOptions(PandasLoadOptions):
"""
Pandas load options while reading and loading a Parquet file.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"""This class is deprecated, and will be removed in astro-sdk-python>=2.0.0.
Please use `astro.dataframes.load_options.PandasLoadOptions`.""",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
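
A quick way to exercise the restored classes (a sketch, not part of this PR's test suite; it relies only on the warning text added above):

```python
import pytest

from astro.dataframes.load_options import PandasJsonLoadOptions


def test_pandas_json_load_options_warns_on_instantiation():
    # Each restored 1.4 class should emit the DeprecationWarning defined above.
    with pytest.warns(DeprecationWarning, match="deprecated"):
        PandasJsonLoadOptions(encoding="utf-8")
```
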
4 changes: 2 additions & 2 deletions python-sdk/src/astro/files/locations/azure/wasb.py
@@ -16,7 +16,7 @@ class WASBLocation(BaseFileLocation):

location_type = FileLocation.WASB
supported_conn_type = {WasbHook.conn_type, "wasbs"}
LOAD_OPTIONS_CLASS_NAME = "WASBLocationLoadOptions"
LOAD_OPTIONS_CLASS_NAME = ("WASBLocationLoadOptions",)
AZURE_HOST = "blob.core.windows.net"

def exists(self) -> bool:
@@ -105,7 +105,7 @@ def snowflake_stage_path(self) -> str:
"""
if not contains_required_option(self.load_options, "storage_account"):
raise ValueError(
f"Required param missing 'storage_account', pass {self.LOAD_OPTIONS_CLASS_NAME}"
f"Required param missing 'storage_account', pass {self.LOAD_OPTIONS_CLASS_NAME[0]}"
f"(storage_account=<account_name>) to load_options"
)
url = urlparse(self.path)
2 changes: 1 addition & 1 deletion python-sdk/src/astro/files/types/csv.py
@@ -14,7 +14,7 @@
class CSVFileType(FileType):
"""Concrete implementation to handle CSV file type"""

LOAD_OPTIONS_CLASS_NAME = "PandasLoadOptions"
LOAD_OPTIONS_CLASS_NAME = ("PandasCsvLoadOptions", "PandasLoadOptions")

# We need skipcq because it's a method overloading so we don't want to make it a static method
def export_to_dataframe(
2 changes: 1 addition & 1 deletion python-sdk/src/astro/files/types/json.py
@@ -14,7 +14,7 @@
class JSONFileType(FileType):
"""Concrete implementation to handle JSON file type"""

LOAD_OPTIONS_CLASS_NAME = "PandasLoadOptions"
LOAD_OPTIONS_CLASS_NAME = ("PandasJsonLoadOptions", "PandasLoadOptions")

# We need skipcq because it's a method overloading so we don't want to make it a static method
def export_to_dataframe(
2 changes: 1 addition & 1 deletion python-sdk/src/astro/files/types/ndjson.py
@@ -15,7 +15,7 @@
class NDJSONFileType(FileType):
"""Concrete implementation to handle NDJSON file type"""

LOAD_OPTIONS_CLASS_NAME = "PandasLoadOptions"
LOAD_OPTIONS_CLASS_NAME = ("PandasNdjsonLoadOptions", "PandasLoadOptions")

def export_to_dataframe(
self,
2 changes: 1 addition & 1 deletion python-sdk/src/astro/files/types/parquet.py
@@ -14,7 +14,7 @@
class ParquetFileType(FileType):
"""Concrete implementation to handle Parquet file type"""

LOAD_OPTIONS_CLASS_NAME = "PandasLoadOptions"
LOAD_OPTIONS_CLASS_NAME = ("PandasParquetLoadOptions", "PandasLoadOptions")

def export_to_dataframe(
self,
7 changes: 6 additions & 1 deletion python-sdk/src/astro/options.py
@@ -42,7 +42,12 @@ def get(self, option_class) -> Optional[LoadOptions]:
"""
if not hasattr(option_class, "LOAD_OPTIONS_CLASS_NAME"):
return None
return self.get_by_class_name(option_class.LOAD_OPTIONS_CLASS_NAME)
cls = None
for cls_name in option_class.LOAD_OPTIONS_CLASS_NAME:
cls = self.get_by_class_name(cls_name)
if cls is not None:
break
return cls

def get_by_class_name(self, option_class_name) -> Optional[LoadOptions]:
"""
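
Roughly how the tuple lookup above resolves (a hedged sketch; it assumes `LoadOptionsList` wraps a plain list of options and that `get_by_class_name` matches on the option's class name, as in the existing implementation):

```python
from astro.dataframes.load_options import PandasCsvLoadOptions, PandasLoadOptions
from astro.files.types import CSVFileType
from astro.options import LoadOptionsList

# CSVFileType.LOAD_OPTIONS_CLASS_NAME is now ("PandasCsvLoadOptions", "PandasLoadOptions"),
# so the deprecated per-format class is resolved first when both are supplied.
options = LoadOptionsList([PandasLoadOptions(delimiter=","), PandasCsvLoadOptions(delimiter="$")])
resolved = options.get(CSVFileType)
assert type(resolved) is PandasCsvLoadOptions
```
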
34 changes: 32 additions & 2 deletions python-sdk/src/astro/sql/operators/load_file.py
@@ -12,6 +12,12 @@
from astro.constants import DEFAULT_CHUNK_SIZE, ColumnCapitalization, LoadExistStrategy
from astro.databases import create_database
from astro.databases.base import BaseDatabase
from astro.dataframes.load_options import (
PandasCsvLoadOptions,
PandasJsonLoadOptions,
PandasNdjsonLoadOptions,
PandasParquetLoadOptions,
)
from astro.dataframes.pandas import PandasDataframe
from astro.files import File, resolve_file_path_pattern
from astro.options import LoadOptions, LoadOptionsList
@@ -20,6 +26,13 @@
from astro.table import BaseTable
from astro.utils.compat.typing import Context

DEPRECATED_LOAD_OPTIONS_CLASSES = [
PandasCsvLoadOptions,
PandasJsonLoadOptions,
PandasNdjsonLoadOptions,
PandasParquetLoadOptions,
]


class LoadFileOperator(AstroSQLBaseOperator):
"""Load S3/local file into either a database or a pandas dataframe
@@ -50,7 +63,7 @@ def __init__(
ndjson_normalize_sep: str = "_",
use_native_support: bool = True,
native_support_kwargs: dict | None = None,
load_options: list[LoadOptions] | None = None,
load_options: LoadOptions | list[LoadOptions] | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = LOAD_FILE_ENABLE_NATIVE_FALLBACK,
**kwargs,
@@ -70,6 +83,23 @@
DeprecationWarning,
stacklevel=2,
)

if load_options is not None:
if not isinstance(load_options, list):
load_options = [load_options]

deprecated_classes = {type(cls).__name__ for cls in load_options}.intersection(
[cls.__name__ for cls in DEPRECATED_LOAD_OPTIONS_CLASSES]
)
if deprecated_classes:
warnings.warn(
f'`{", ".join(deprecated_classes)}` will be replaced by'
f" `astro.dataframes.load_options.PandasLoadOptions` in astro-sdk-python>=2.0.0."
f" Please use `astro.dataframes.load_options.PandasLoadOptions` class instead.",
DeprecationWarning,
stacklevel=2,
)

self.output_table = output_table
self.input_file = input_file
self.input_file.load_options = load_options
@@ -314,7 +344,7 @@ def load_file(
native_support_kwargs: dict | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = True,
load_options: list[LoadOptions] | None = None,
load_options: LoadOptions | list[LoadOptions] | None = None,
**kwargs: Any,
) -> XComArg:
"""Load a file or bucket into either a SQL table or a pandas dataframe.
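
For callers, the net effect of the signature change is that `load_options` may now be a single `LoadOptions` instance or a list, with the deprecated per-format classes still accepted but warned about. A hedged usage sketch (connection IDs, paths, table names and task IDs are placeholders; DAG arguments assume Airflow 2.4+):

```python
import pendulum
from airflow import DAG

from astro import sql as aql
from astro.dataframes.load_options import PandasCsvLoadOptions, PandasLoadOptions
from astro.files import File
from astro.table import Table

with DAG(dag_id="load_options_example", start_date=pendulum.datetime(2023, 1, 1), schedule=None):
    # A bare LoadOptions instance is now accepted; the operator wraps it into a list internally.
    aql.load_file(
        task_id="load_with_new_options",
        input_file=File("s3://my-bucket/data.csv", conn_id="aws_default"),
        output_table=Table(name="my_table", conn_id="snowflake_default"),
        load_options=PandasLoadOptions(delimiter=","),
    )

    # A restored 1.4 class still works here, but emits the DeprecationWarning shown above.
    aql.load_file(
        task_id="load_with_deprecated_options",
        input_file=File("s3://my-bucket/data.csv", conn_id="aws_default"),
        output_table=Table(name="my_table", conn_id="snowflake_default"),
        load_options=[PandasCsvLoadOptions(delimiter=",")],
    )
```
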
2 changes: 1 addition & 1 deletion python-sdk/tests/files/locations/test_wasb.py
@@ -52,7 +52,7 @@ def test_snowflake_stage_path_raise_exception():
Test snowflake_stage_path raise exception when 'storage_account' is missing.
"""
location = WASBLocation(path="azure://somepath")
error_message = f"Required param missing 'storage_account', pass {location.LOAD_OPTIONS_CLASS_NAME}"
error_message = f"Required param missing 'storage_account', pass {location.LOAD_OPTIONS_CLASS_NAME[0]}"
"(storage_account=<account_name>) to load_options"
with pytest.raises(ValueError, match=error_message):
location.snowflake_stage_path
46 changes: 45 additions & 1 deletion python-sdk/tests/files/test_file.py
@@ -8,7 +8,13 @@
from airflow import DAG

from astro import constants
from astro.dataframes.load_options import PandasLoadOptions
from astro.dataframes.load_options import (
PandasCsvLoadOptions,
PandasJsonLoadOptions,
PandasLoadOptions,
PandasNdjsonLoadOptions,
PandasParquetLoadOptions,
)
from astro.files import File, get_file_list, resolve_file_path_pattern
from astro.options import SnowflakeLoadOptions, WASBLocationLoadOptions

@@ -256,3 +262,41 @@ def test_file_object_picks_load_options(file_type, file_location):
]
assert type(file.type.load_options) is type_expected_class
assert file.location.load_options is location_expected_class


@pytest.mark.parametrize(
"file_type",
[
{"type": "csv", "expected_class": PandasCsvLoadOptions},
{"type": "ndjson", "expected_class": PandasNdjsonLoadOptions},
{"type": "json", "expected_class": PandasJsonLoadOptions},
{"type": "parquet", "expected_class": PandasParquetLoadOptions},
],
ids=["csv", "ndjson", "json", "parquet"],
)
@pytest.mark.parametrize(
"file_location",
[
{"location": "s3://dummy/test", "expected_class": None},
{"location": "gs://dummy/test", "expected_class": None},
{"location": "ftp://dummy/test", "expected_class": None},
{"location": "sftp://dummy/test", "expected_class": None},
{"location": "gdrive://dummy/test", "expected_class": None},
{"location": "http://dummy.com/test", "expected_class": None},
{"location": "https://dummy.com/test", "expected_class": None},
{"location": "./test", "expected_class": None}, # local path
],
ids=["s3", "gs", "ftp", "sftp", "gdrive", "http", "https", "local"],
)
def test_file_object_picks_load_options_with_deprecated_load_options(file_type, file_location):
"""Test file object pick correct load_options"""
type_name, type_expected_class = file_type.values()
location_path, _ = file_location.values()
file = File(path=location_path + f".{type_name}")
file.load_options = [
PandasCsvLoadOptions(delimiter="$"),
PandasJsonLoadOptions(encoding="test"),
PandasParquetLoadOptions(columns=["name", "age"]),
PandasNdjsonLoadOptions(normalize_sep="__"),
]
assert type(file.type.load_options) is type_expected_class
12 changes: 11 additions & 1 deletion python-sdk/tests/files/type/test_csv.py
@@ -4,7 +4,7 @@

import pandas as pd

from astro.dataframes.load_options import PandasLoadOptions
from astro.dataframes.load_options import PandasCsvLoadOptions, PandasLoadOptions
from astro.dataframes.pandas import PandasDataframe
from astro.files.types import CSVFileType

@@ -31,6 +31,16 @@ def test_read_csv_file_with_pandas_opts(mock_read_csv):
mock_read_csv.assert_called_once_with(file, delimiter="$")


@mock.patch("astro.files.types.csv.pd.read_csv")
def test_read_csv_file_with_pandas_opts_with_deprecated_load_options(mock_read_csv):
"""Test pandas option get pass to read_csv"""
path = str(sample_file.absolute())
csv_type = CSVFileType(path, load_options=PandasCsvLoadOptions(delimiter="$"))
with open(path) as file:
csv_type.export_to_dataframe(file)
mock_read_csv.assert_called_once_with(file, delimiter="$")


def test_write_csv_file():
"""Test writing of csv file from local location"""
with tempfile.NamedTemporaryFile() as temp_file:
12 changes: 11 additions & 1 deletion python-sdk/tests/files/type/test_json.py
@@ -4,7 +4,7 @@

import pandas as pd

from astro.dataframes.load_options import PandasLoadOptions
from astro.dataframes.load_options import PandasJsonLoadOptions, PandasLoadOptions
from astro.dataframes.pandas import PandasDataframe
from astro.files.types import JSONFileType

@@ -31,6 +31,16 @@ def test_read_json_file_with_pandas_opts(mock_read_json):
mock_read_json.assert_called_once_with(file, encoding="utf-8")


@mock.patch("astro.files.types.json.pd.read_json")
def test_read_json_file_with_pandas_opts_with_deprecated_load_options(mock_read_json):
"""Test pandas option get pass to read_json"""
path = str(sample_file.absolute())
json_type = JSONFileType(path, load_options=PandasJsonLoadOptions(encoding="utf-8"))
with open(path) as file:
json_type.export_to_dataframe(file)
mock_read_json.assert_called_once_with(file, encoding="utf-8")


def test_write_json_file():
"""Test writing of json file from local location"""
with tempfile.NamedTemporaryFile() as temp_file:
12 changes: 11 additions & 1 deletion python-sdk/tests/files/type/test_ndjson.py
@@ -5,7 +5,7 @@

import pandas as pd

from astro.dataframes.load_options import PandasLoadOptions
from astro.dataframes.load_options import PandasLoadOptions, PandasNdjsonLoadOptions
from astro.dataframes.pandas import PandasDataframe
from astro.files.types import NDJSONFileType
from astro.settings import NEED_CUSTOM_SERIALIZATION
@@ -36,6 +36,16 @@ def test_read_ndjson_file_with_pandas_opts(mock_ndjson_flatten):
mock_ndjson_flatten.assert_called_once_with(None, file, normalize_sep="_")


@mock.patch("astro.files.types.ndjson.NDJSONFileType.flatten")
def test_read_ndjson_file_with_pandas_opts_with_deprecated_load_options(mock_ndjson_flatten):
"""Test pandas option get pass to ndjson_flatten method"""
path = str(sample_file.absolute())
ndjson_type = NDJSONFileType(path, load_options=PandasNdjsonLoadOptions(normalize_sep="_"))
with open(path) as file:
ndjson_type.export_to_dataframe(file)
mock_ndjson_flatten.assert_called_once_with(None, file, normalize_sep="_")


def test_write_ndjson_file():
"""Test writing of ndjson file from local location"""
with tempfile.NamedTemporaryFile() as temp_file: