Add Snowflake load options (#1516)
# Description
## What is the current behavior?
Loading data into Snowflake currently depends on `native_support_kwargs`,
which is loosely defined and lacks type safety. There is also no way to
specify Snowflake `FILE_FORMAT` options, which can prevent users from
loading their data at all.

## What is the new behavior?
This PR introduces a `SnowflakeLoadOptions` class with `file_options` and
`copy_options` fields. These fields let users specify the arguments needed
to load their data correctly.
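
For example, a caller might construct the options like this (a minimal sketch; `SKIP_HEADER` and `ON_ERROR` are illustrative Snowflake file-format/copy options rather than anything this PR mandates):

```python
from astro.options import SnowflakeLoadOptions

# Illustrative values: SKIP_HEADER is a Snowflake CSV FILE_FORMAT option,
# ON_ERROR is a Snowflake COPY option.
load_options = SnowflakeLoadOptions(
    file_options={"SKIP_HEADER": 1},
    copy_options={"ON_ERROR": "CONTINUE"},
)
```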

## Does this introduce a breaking change?
No, this does not introduce a breaking change.

### Checklist
- [ ] Created tests which fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary
dimberman authored Jan 10, 2023
1 parent 116cf74 commit acd7ee4
Showing 6 changed files with 116 additions and 9 deletions.
3 changes: 3 additions & 0 deletions python-sdk/src/astro/databases/aws/redshift.py
@@ -31,6 +31,7 @@
from astro.databases.base import BaseDatabase
from astro.exceptions import DatabaseCustomError
from astro.files import File
from astro.options import LoadOptions
from astro.settings import REDSHIFT_SCHEMA
from astro.table import BaseTable, Metadata, Table

@@ -315,12 +316,14 @@ def load_file_to_table_natively(
target_table: BaseTable,
if_exists: LoadExistStrategy = "replace",
native_support_kwargs: dict | None = None,
load_options: LoadOptions | None = None,
**kwargs,
):
"""
Checks whether an optimised transfer path from the file location to the database
exists; if it does, transfers the data and returns True, else returns False.
:param load_options: Database-specific options for loading
:param source_file: File from which we need to transfer data
:param target_table: Table that needs to be populated with file data
:param if_exists: Strategy to use if the table already exists. Default "replace"
4 changes: 4 additions & 0 deletions python-sdk/src/astro/databases/base.py
@@ -562,6 +562,7 @@ def load_file_to_table_natively_with_fallback(
target_table=target_table,
if_exists=if_exists,
native_support_kwargs=native_support_kwargs,
load_options=load_options,
**kwargs,
)
except self.NATIVE_LOAD_EXCEPTIONS as load_exception: # skipcq: PYL-W0703
@@ -812,12 +813,15 @@ def load_file_to_table_natively(
target_table: BaseTable,
if_exists: LoadExistStrategy = "replace",
native_support_kwargs: dict | None = None,
load_options: LoadOptions | None = None,
**kwargs,
):
"""
Checks whether an optimised transfer path from the file location to the database
exists; if it does, transfers the data and returns True, else returns False.
:param load_options: Options for database-specific loading
parameters (e.g. SnowflakeLoadOptions or DeltaLoadOptions)
:param source_file: File from which we need to transfer data
:param target_table: Table that needs to be populated with file data
:param if_exists: Strategy to use if the table already exists. Default "replace"
3 changes: 3 additions & 0 deletions python-sdk/src/astro/databases/google/bigquery.py
@@ -45,6 +45,7 @@
from astro.databases.base import BaseDatabase
from astro.exceptions import DatabaseCustomError
from astro.files import File
from astro.options import LoadOptions
from astro.settings import BIGQUERY_SCHEMA, BIGQUERY_SCHEMA_LOCATION
from astro.table import BaseTable, Metadata

@@ -347,12 +348,14 @@ def load_file_to_table_natively(
target_table: BaseTable,
if_exists: LoadExistStrategy = "replace",
native_support_kwargs: dict | None = None,
load_options: LoadOptions | None = None,
**kwargs,
):
"""
Checks whether an optimised transfer path from the file location to the database
exists; if it does, transfers the data and returns True, else returns False.
:param load_options: Database-specific options for loading
:param source_file: File from which we need to transfer data
:param target_table: Table that needs to be populated with file data
:param if_exists: Strategy to use if the table already exists. Default "replace"
35 changes: 28 additions & 7 deletions python-sdk/src/astro/databases/snowflake.py
@@ -38,7 +38,7 @@
from astro.databases.base import BaseDatabase
from astro.exceptions import DatabaseCustomError
from astro.files import File
from astro.options import LoadOptions
from astro.options import LoadOptions, SnowflakeLoadOptions
from astro.settings import LOAD_TABLE_AUTODETECT_ROWS_COUNT, SNOWFLAKE_SCHEMA
from astro.table import BaseTable, Metadata

@@ -368,6 +368,7 @@ def create_stage(
file: File,
storage_integration: str | None = None,
metadata: Metadata | None = None,
load_options: SnowflakeLoadOptions | None = None,
) -> SnowflakeStage:
"""
Creates a new named external stage to use for loading data from files into Snowflake
@@ -390,18 +391,24 @@
"""
auth = self._create_stage_auth_sub_statement(file=file, storage_integration=storage_integration)

if not load_options:
load_options = SnowflakeLoadOptions()
metadata = metadata or self.default_metadata
stage = SnowflakeStage(metadata=metadata)
stage.set_url_from_file(file)

fileformat = ASTRO_SDK_TO_SNOWFLAKE_FILE_FORMAT_MAP[file.type.name]
copy_options = COPY_OPTIONS[file.type.name]

copy_options = [COPY_OPTIONS[file.type.name]]
copy_options.extend([f"{k}={v}" for k, v in load_options.copy_options.items()])
file_options = [f"{k}={v}" for k, v in load_options.file_options.items()]
file_options.extend([f"TYPE={fileformat}", "TRIM_SPACE=TRUE"])
file_options_str = ", ".join(file_options)
copy_options_str = ", ".join(copy_options)
sql_statement = "".join(
[
f"CREATE OR REPLACE STAGE {stage.qualified_name} URL='{stage.url}' ",
f"FILE_FORMAT=(TYPE={fileformat}, TRIM_SPACE=TRUE) ",
f"COPY_OPTIONS=({copy_options}) ",
f"FILE_FORMAT=({file_options_str}) ",
f"COPY_OPTIONS=({copy_options_str}) ",
auth,
]
)
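
For a CSV source, the option lists built above render into a statement along these lines (a sketch: the stage name, URL, and storage integration are hypothetical placeholders, and `ON_ERROR=CONTINUE` is the CSV entry in `COPY_OPTIONS`):

```python
# Sketch of the statement create_stage would build for a CSV file given
# SnowflakeLoadOptions(file_options={"SKIP_HEADER": 1}); the stage name,
# URL, and auth clause are hypothetical.
sql_statement = (
    "CREATE OR REPLACE STAGE my_schema.stage_1 URL='s3://my-bucket/data/' "
    "FILE_FORMAT=(SKIP_HEADER=1, TYPE=CSV, TRIM_SPACE=TRUE) "
    "COPY_OPTIONS=(ON_ERROR=CONTINUE) "
    "STORAGE_INTEGRATION=my_integration"
)
```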
@@ -577,6 +584,7 @@ def load_file_to_table_natively(
target_table: BaseTable,
if_exists: LoadExistStrategy = "replace",
native_support_kwargs: dict | None = None,
load_options: LoadOptions | None = None,
**kwargs,
): # skipcq PYL-W0613
"""
@@ -591,6 +599,7 @@
retrieved from the Airflow connection or from the `storage_integration`
attribute within `native_support_kwargs`.
:param load_options: File-format and copy options to apply when loading data to Snowflake
:param source_file: File from which we need to transfer data
:param target_table: Table to which the content of the file will be loaded to
:param if_exists: Strategy used to load (currently supported: "append" or "replace")
@@ -605,8 +614,20 @@
"""
native_support_kwargs = native_support_kwargs or {}
storage_integration = native_support_kwargs.get("storage_integration")
stage = self.create_stage(file=source_file, storage_integration=storage_integration)
if not load_options:
load_options = SnowflakeLoadOptions()
if not isinstance(load_options, SnowflakeLoadOptions):
raise ValueError("Error: Requires a SnowflakeLoadOptions")
stage = self.create_stage(
file=source_file, storage_integration=storage_integration, load_options=load_options
)

rows = self._copy_into_table_from_stage(
source_file=source_file, target_table=target_table, stage=stage
)
self.evaluate_results(rows)

def _copy_into_table_from_stage(self, source_file, target_table, stage):
table_name = self.get_table_qualified_name(target_table)
file_path = os.path.basename(source_file.path) or ""
sql_statement = f"COPY INTO {table_name} FROM @{stage.qualified_name}/{file_path}"
@@ -625,7 +646,7 @@ def load_file_to_table_natively(
raise DatabaseCustomError from exe
finally:
self.drop_stage(stage)
self.evaluate_results(rows)
return rows

@staticmethod
def evaluate_results(rows):
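
End to end, the native path now validates the options type, builds the stage with them, and runs the `COPY INTO` before dropping the stage. A usage sketch (the connection ID, S3 path, table name, and storage integration are hypothetical):

```python
from astro.databases.snowflake import SnowflakeDatabase
from astro.files import File
from astro.options import SnowflakeLoadOptions
from astro.table import Table

# Hypothetical conn ID, S3 path, table name, and storage integration;
# this would issue real SQL against the configured Snowflake connection.
database = SnowflakeDatabase(conn_id="snowflake_default")
database.load_file_to_table_natively(
    source_file=File("s3://my-bucket/homes_main.csv"),
    target_table=Table(name="homes"),
    native_support_kwargs={"storage_integration": "aws_int_python_sdk"},
    load_options=SnowflakeLoadOptions(copy_options={"ON_ERROR": "CONTINUE"}),
)
```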
9 changes: 9 additions & 0 deletions python-sdk/src/astro/options.py
@@ -5,3 +5,12 @@
class LoadOptions:
def empty(self):
raise NotImplementedError()


@attr.define
class SnowflakeLoadOptions(LoadOptions):
file_options: dict = attr.field(init=True, factory=dict)
copy_options: dict = attr.field(init=True, factory=dict)

def empty(self):
return not self.file_options and not self.copy_options
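
`attr.field(factory=dict)` gives each instance its own empty dicts, so mutating one instance's options never affects another. A quick sketch of the semantics (`PURGE` is an illustrative Snowflake copy option):

```python
from astro.options import SnowflakeLoadOptions

a = SnowflakeLoadOptions()
b = SnowflakeLoadOptions()
a.copy_options["PURGE"] = "TRUE"  # only mutates a's dict
assert b.copy_options == {}  # b keeps its own, independent dict
assert not a.empty() and b.empty()
```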
71 changes: 69 additions & 2 deletions python-sdk/tests/databases/test_snowflake.py
@@ -1,19 +1,21 @@
"""Tests specific to the Snowflake Database implementation."""
import pathlib
from unittest.mock import patch
from unittest import mock
from unittest.mock import MagicMock, PropertyMock, patch

import pytest

from astro.databases.snowflake import SnowflakeDatabase, SnowflakeFileFormat, SnowflakeStage
from astro.files import File
from astro.options import LoadOptions, SnowflakeLoadOptions
from astro.settings import SNOWFLAKE_STORAGE_INTEGRATION_AMAZON, SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE
from astro.table import Table

DEFAULT_CONN_ID = "snowflake_default"
CUSTOM_CONN_ID = "snowflake_conn"
SUPPORTED_CONN_IDS = [CUSTOM_CONN_ID]
CWD = pathlib.Path(__file__).parent


SNOWFLAKE_STORAGE_INTEGRATION_AMAZON = SNOWFLAKE_STORAGE_INTEGRATION_AMAZON or "aws_int_python_sdk"
SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE = SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE or "gcs_int_python_sdk"

@@ -74,3 +76,68 @@ def test_use_quotes(cols_eval):
Verify the quotes addition only in case where we are having mixed case col names
"""
assert SnowflakeDatabase.use_quotes(cols_eval["cols"]) == cols_eval["expected_result"]


def test_snowflake_load_options():
path = str(CWD) + "/../../data/homes_main.csv"
database = SnowflakeDatabase(conn_id="fake-conn")
file = File(path)
with mock.patch(
"astro.databases.snowflake.SnowflakeDatabase.hook", new_callable=PropertyMock
), mock.patch(
"astro.databases.snowflake.SnowflakeStage.qualified_name", new_callable=PropertyMock
) as mock_q_name:
mock_q_name.return_value = "foo"
database.run_sql = MagicMock()
database.create_stage(
file=file,
storage_integration="foo",
load_options=SnowflakeLoadOptions(file_options={"foo": "bar"}),
)
assert "FILE_FORMAT=(foo=bar, TYPE=CSV, TRIM_SPACE=TRUE)" in database.run_sql.call_args[0][0]
assert "COPY_OPTIONS=(ON_ERROR=CONTINUE)" in database.run_sql.call_args[0][0]


def test_snowflake_load_options_default():
path = str(CWD) + "/../../data/homes_main.csv"
database = SnowflakeDatabase(conn_id="fake-conn")
file = File(path)
with mock.patch(
"astro.databases.snowflake.SnowflakeDatabase.hook", new_callable=PropertyMock
), mock.patch(
"astro.databases.snowflake.SnowflakeStage.qualified_name", new_callable=PropertyMock
) as mock_q_name:
mock_q_name.return_value = "foo"
database.run_sql = MagicMock()
database.create_stage(
file=file,
storage_integration="foo",
load_options=SnowflakeLoadOptions(),
)
assert "FILE_FORMAT=(TYPE=CSV, TRIM_SPACE=TRUE)" in database.run_sql.call_args[0][0]
assert "COPY_OPTIONS=(ON_ERROR=CONTINUE)" in database.run_sql.call_args[0][0]


def test_snowflake_load_options_wrong_options():
path = str(CWD) + "/../../data/homes_main.csv"
database = SnowflakeDatabase(conn_id="fake-conn")
file = File(path)
with pytest.raises(ValueError, match="Error: Requires a SnowflakeLoadOptions"):
database.load_file_to_table_natively(
source_file=file,
target_table=Table(),
load_options=LoadOptions(),
)


def test_snowflake_load_options_empty():
load_options = SnowflakeLoadOptions()
assert load_options.empty()
load_options.copy_options = {"foo": "bar"}
assert not load_options.empty()
load_options.file_options = {"biz": "baz"}
assert not load_options.empty()
load_options.copy_options = {}
assert not load_options.empty()
load_options.file_options = {}
assert load_options.empty()
