Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a snowflake load options #1516

Merged
merged 10 commits into from
Jan 10, 2023
7 changes: 7 additions & 0 deletions python-sdk/src/astro/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ def load_file_to_table(
normalize_config=normalize_config,
native_support_kwargs=native_support_kwargs,
enable_native_fallback=enable_native_fallback,
load_options=load_options,
chunk_size=chunk_size,
)
else:
Expand Down Expand Up @@ -521,11 +522,13 @@ def load_file_to_table_natively_with_fallback(
native_support_kwargs: dict | None = None,
enable_native_fallback: bool | None = LOAD_FILE_ENABLE_NATIVE_FALLBACK,
chunk_size: int = DEFAULT_CHUNK_SIZE,
load_options: LoadOptions | None = None,
**kwargs,
):
"""
Load content of a file in output_table.

:param load_options: Options for database specific loading parameters (e.g. SnowflakeLoadOptions or DeltaLoadOptions)
dimberman marked this conversation as resolved.
Show resolved Hide resolved
:param source_file: File path and conn_id for object stores
:param target_table: Table to create
:param if_exists: Overwrite file if exists
Expand All @@ -542,6 +545,7 @@ def load_file_to_table_natively_with_fallback(
target_table=target_table,
if_exists=if_exists,
native_support_kwargs=native_support_kwargs,
load_options=load_options,
**kwargs,
)
except self.NATIVE_LOAD_EXCEPTIONS as load_exception: # skipcq: PYL-W0703
Expand Down Expand Up @@ -791,12 +795,15 @@ def load_file_to_table_natively(
target_table: BaseTable,
if_exists: LoadExistStrategy = "replace",
native_support_kwargs: dict | None = None,
load_options: LoadOptions | None = None,
**kwargs,
):
"""
Checks if optimised path for transfer between File location to database exists
and if it does, it transfers it and returns true else false

:param load_options: Options for database specific loading
parameters (e.g. SnowflakeLoadOptions or DeltaLoadOptions)
:param source_file: File from which we need to transfer data
:param target_table: Table that needs to be populated with file data
:param if_exists: Overwrite file if exists. Default False
Expand Down
30 changes: 24 additions & 6 deletions python-sdk/src/astro/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from astro.databases.base import BaseDatabase
from astro.exceptions import DatabaseCustomError
from astro.files import File
from astro.options import LoadOptions, SnowflakeLoadOptions
from astro.settings import LOAD_TABLE_AUTODETECT_ROWS_COUNT, SNOWFLAKE_SCHEMA
from astro.table import BaseTable, Metadata

Expand Down Expand Up @@ -367,6 +368,7 @@ def create_stage(
file: File,
storage_integration: str | None = None,
metadata: Metadata | None = None,
load_options: SnowflakeLoadOptions = SnowflakeLoadOptions(),
) -> SnowflakeStage:
"""
Creates a new named external stage to use for loading data from files into Snowflake
Expand Down Expand Up @@ -394,13 +396,17 @@ def create_stage(
stage.set_url_from_file(file)

fileformat = ASTRO_SDK_TO_SNOWFLAKE_FILE_FORMAT_MAP[file.type.name]
copy_options = COPY_OPTIONS[file.type.name]

copy_options = [COPY_OPTIONS[file.type.name]]
copy_options.extend([f"{k}={v}" for k, v in load_options.copy_options.items()])
file_options = [f"{k}={v}" for k, v in load_options.file_options.items()]
file_options.extend([f"TYPE={fileformat}", "TRIM_SPACE=TRUE"])
file_options_str = ", ".join(file_options)
copy_options_str = ", ".join(copy_options)
sql_statement = "".join(
[
f"CREATE OR REPLACE STAGE {stage.qualified_name} URL='{stage.url}' ",
f"FILE_FORMAT=(TYPE={fileformat}, TRIM_SPACE=TRUE) ",
f"COPY_OPTIONS=({copy_options}) ",
f"FILE_FORMAT=({file_options_str}) ",
f"COPY_OPTIONS=({copy_options_str}) ",
auth,
]
)
Expand Down Expand Up @@ -573,6 +579,7 @@ def load_file_to_table_natively(
target_table: BaseTable,
if_exists: LoadExistStrategy = "replace",
native_support_kwargs: dict | None = None,
load_options: LoadOptions | None = None,
**kwargs,
): # skipcq PYL-W0613
"""
Expand All @@ -587,6 +594,7 @@ def load_file_to_table_natively(
retrieved from the Airflow connection or from the `storage_integration`
attribute within `native_support_kwargs`.

:param load_options: Options for database specific loading parameters (e.g. SnowflakeLoadOptions)
:param source_file: File from which we need to transfer data
:param target_table: Table to which the content of the file will be loaded to
:param if_exists: Strategy used to load (currently supported: "append" or "replace")
Expand All @@ -601,7 +609,13 @@ def load_file_to_table_natively(
"""
native_support_kwargs = native_support_kwargs or {}
storage_integration = native_support_kwargs.get("storage_integration")
stage = self.create_stage(file=source_file, storage_integration=storage_integration)
if not load_options:
load_options = SnowflakeLoadOptions()
if not isinstance(load_options, SnowflakeLoadOptions):
raise ValueError("Error: Requires a SnowflakeLoadOptions")
dimberman marked this conversation as resolved.
Show resolved Hide resolved
stage = self.create_stage(
file=source_file, storage_integration=storage_integration, load_options=load_options
)

table_name = self.get_table_qualified_name(target_table)
file_path = os.path.basename(source_file.path) or ""
dimberman marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -610,6 +624,10 @@ def load_file_to_table_natively(
# Below code is added due to breaking change in apache-airflow-providers-snowflake==3.2.0,
# we need to pass handler param to get the rows. But in version apache-airflow-providers-snowflake==3.1.0
# if we pass the handler provider raises an exception AttributeError
rows = self._copy_into_table_from_stage(sql_statement, stage)
self.evaluate_results(rows)

def _copy_into_table_from_stage(self, sql_statement, stage):
try:
rows = self.hook.run(sql_statement, handler=lambda cur: cur.fetchall())
except AttributeError:
Expand All @@ -621,7 +639,7 @@ def load_file_to_table_natively(
raise DatabaseCustomError from exe
finally:
self.drop_stage(stage)
self.evaluate_results(rows)
return rows

@staticmethod
def evaluate_results(rows):
Expand Down
9 changes: 9 additions & 0 deletions python-sdk/src/astro/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,12 @@
class LoadOptions:
    """Base class for database-specific load options (e.g. SnowflakeLoadOptions)."""

    def empty(self):
        """Return ``True`` when no options are set.

        Subclasses must override this method.

        :raises NotImplementedError: always, on the base class.
        """
        # Bug fix: the original did `return NotImplementedError()`, which
        # returned a (truthy!) exception instance instead of raising it, so
        # callers doing `if load_options.empty():` would silently succeed.
        raise NotImplementedError()


@attr.define
class SnowflakeLoadOptions(LoadOptions):
    """Snowflake-specific load options passed through to stage creation."""

    # Key/value pairs rendered into the stage's FILE_FORMAT clause.
    file_options: dict = attr.field(init=True, factory=dict)
    # Key/value pairs rendered into the stage's COPY_OPTIONS clause.
    copy_options: dict = attr.field(init=True, factory=dict)

    def empty(self):
        """Return ``True`` when neither file nor copy options were supplied."""
        return not (self.file_options or self.copy_options)
dimberman marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions python-sdk/src/astro/sql/operators/load_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def load_file(
native_support_kwargs: dict | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
enable_native_fallback: bool | None = True,
load_options: LoadOptions | None = None,
**kwargs: Any,
) -> XComArg:
"""Load a file or bucket into either a SQL table or a pandas dataframe.
Expand Down Expand Up @@ -331,6 +332,7 @@ def load_file(
native_support_kwargs=native_support_kwargs,
columns_names_capitalization=columns_names_capitalization,
enable_native_fallback=enable_native_fallback,
load_options=load_options,
**kwargs,
).output

Expand Down
44 changes: 43 additions & 1 deletion python-sdk/tests/databases/test_snowflake.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Tests specific to the Snowflake Database implementation."""
import pathlib
from unittest.mock import patch
from unittest import mock
from unittest.mock import MagicMock, PropertyMock, patch

import pytest

from astro.databases.snowflake import SnowflakeDatabase, SnowflakeFileFormat, SnowflakeStage
from astro.files import File
from astro.options import SnowflakeLoadOptions
from astro.settings import SNOWFLAKE_STORAGE_INTEGRATION_AMAZON, SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE

DEFAULT_CONN_ID = "snowflake_default"
Expand Down Expand Up @@ -74,3 +76,43 @@ def test_use_quotes(cols_eval):
Verify the quotes addition only in case where we are having mixed case col names
"""
assert SnowflakeDatabase.use_quotes(cols_eval["cols"]) == cols_eval["expected_result"]


def test_snowflake_load_options():
    """create_stage() must merge user file_options ahead of the SDK defaults."""
    csv_path = str(CWD) + "/../../data/homes_main.csv"
    db = SnowflakeDatabase(conn_id="fake-conn")
    input_file = File(csv_path)
    hook_patch = mock.patch(
        "astro.databases.snowflake.SnowflakeDatabase.hook", new_callable=PropertyMock
    )
    name_patch = mock.patch(
        "astro.databases.snowflake.SnowflakeStage.qualified_name", new_callable=PropertyMock
    )
    with hook_patch, name_patch as qualified_name:
        qualified_name.return_value = "foo"
        db.run_sql = MagicMock()
        db.create_stage(
            file=input_file,
            storage_integration="foo",
            load_options=SnowflakeLoadOptions(file_options={"foo": "bar"}),
        )
        generated_sql = db.run_sql.call_args[0][0]
        # User-supplied options come first, then the SDK-enforced defaults.
        assert "FILE_FORMAT=(foo=bar, TYPE=CSV, TRIM_SPACE=TRUE)" in generated_sql
        assert "COPY_OPTIONS=(ON_ERROR=CONTINUE)" in generated_sql


def test_snowflake_load_options_default():
    """With empty SnowflakeLoadOptions, create_stage() emits only the defaults."""
    csv_path = str(CWD) + "/../../data/homes_main.csv"
    db = SnowflakeDatabase(conn_id="fake-conn")
    input_file = File(csv_path)
    hook_patch = mock.patch(
        "astro.databases.snowflake.SnowflakeDatabase.hook", new_callable=PropertyMock
    )
    name_patch = mock.patch(
        "astro.databases.snowflake.SnowflakeStage.qualified_name", new_callable=PropertyMock
    )
    with hook_patch, name_patch as qualified_name:
        qualified_name.return_value = "foo"
        db.run_sql = MagicMock()
        db.create_stage(
            file=input_file,
            storage_integration="foo",
            load_options=SnowflakeLoadOptions(),
        )
        generated_sql = db.run_sql.call_args[0][0]
        # No user options supplied, so only the SDK defaults appear.
        assert "FILE_FORMAT=(TYPE=CSV, TRIM_SPACE=TRUE)" in generated_sql
        assert "COPY_OPTIONS=(ON_ERROR=CONTINUE)" in generated_sql