Skip to content

Commit 76270f6

Browse files
authored
fix: Redshift push ignores schema (#3671)
* Add fully-qualified-table-name Redshift prop Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com> * pre-commit Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com> * Docstring Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com> * Test fully_qualified_table_name Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com> * Simplify logic Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com> * pre-commit Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com> * pre-commit Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com> * Test offline_write_batch Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com> * Bump to trigger CI Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com> * another bump for ci Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com> --------- Signed-off-by: Robin Neufeld <metavee@users.noreply.github.com>
1 parent 9527183 commit 76270f6

File tree

4 files changed

+147
-2
lines changed

4 files changed

+147
-2
lines changed

sdk/python/feast/infra/offline_stores/redshift.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def offline_write_batch(
369369
s3_resource=s3_resource,
370370
s3_path=f"{config.offline_store.s3_staging_location}/push/{uuid.uuid4()}.parquet",
371371
iam_role=config.offline_store.iam_role,
372-
table_name=redshift_options.table,
372+
table_name=redshift_options.fully_qualified_table_name,
373373
schema=pa_schema,
374374
fail_if_exists=False,
375375
)

sdk/python/feast/infra/offline_stores/redshift_source.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,42 @@ def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
294294

295295
return redshift_options
296296

297+
@property
def fully_qualified_table_name(self) -> str:
    """
    The fully qualified table name of this Redshift table.

    Any part not spelled out in ``self.table`` falls back to the database /
    schema configured on this options object.

    Returns:
        A string in the format of <database>.<schema>.<table>
        May be empty or None if the table is not set
    """

    if not self.table:
        return ""

    # self.table may already contain the database and schema
    segments = self.table.split(".")
    if len(segments) == 3:
        database, schema, table = segments
    elif len(segments) == 2:
        database = self.database
        schema, table = segments
    elif len(segments) == 1:
        database = self.database
        schema = self.schema
        table = segments[0]
    else:
        raise ValueError(
            f"Invalid table name: {self.table} - can't determine database and schema"
        )

    # Build the name from the innermost part outward; the database prefix is
    # only meaningful when a schema is present as well.
    qualified = [table]
    if schema:
        qualified.insert(0, schema)
        if database:
            qualified.insert(0, database)
    return ".".join(qualified)
332+
297333
def to_proto(self) -> DataSourceProto.RedshiftOptions:
298334
"""
299335
Converts an RedshiftOptionsProto object to its protobuf representation.
@@ -323,7 +359,6 @@ def __init__(self, table_ref: str):
323359

324360
@staticmethod
def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage:
    """Reconstruct a SavedDatasetRedshiftStorage from its protobuf form."""
    options = RedshiftOptions.from_proto(storage_proto.redshift_storage)
    return SavedDatasetRedshiftStorage(table_ref=options.table)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import pandas as pd
4+
import pyarrow as pa
5+
6+
from feast import FeatureView
7+
from feast.infra.offline_stores import offline_utils
8+
from feast.infra.offline_stores.redshift import (
9+
RedshiftOfflineStore,
10+
RedshiftOfflineStoreConfig,
11+
)
12+
from feast.infra.offline_stores.redshift_source import RedshiftSource
13+
from feast.infra.utils import aws_utils
14+
from feast.repo_config import RepoConfig
15+
16+
17+
@patch.object(aws_utils, "upload_arrow_table_to_redshift")
def test_offline_write_batch(
    mock_upload_arrow_table_to_redshift: MagicMock,
    simple_dataset_1: pd.DataFrame,
):
    """offline_write_batch must pass the schema-qualified table name to Redshift.

    Regression test for the push path: it previously forwarded only
    ``redshift_options.table``, dropping the configured schema. Here the
    upload helper is mocked out, so no AWS access is needed.
    """
    repo_config = RepoConfig(
        registry="registry",
        project="project",
        provider="local",
        offline_store=RedshiftOfflineStoreConfig(
            type="redshift",
            region="us-west-2",
            cluster_id="cluster_id",
            database="database",
            user="user",
            iam_role="abcdef",
            s3_staging_location="s3://bucket/path",
        ),
    )

    # Source declares schema and table separately; the two must be joined
    # when the write goes out.
    batch_source = RedshiftSource(
        name="test_source",
        timestamp_field="ts",
        table="table_name",
        schema="schema_name",
    )
    feature_view = FeatureView(
        name="test_view",
        source=batch_source,
    )

    pa_dataset = pa.Table.from_pandas(simple_dataset_1)

    # patch some more things so that the function can run
    def mock_get_pyarrow_schema_from_batch_source(*args, **kwargs) -> "tuple[pa.Schema, list[str]]":
        return pa_dataset.schema, pa_dataset.column_names

    with patch.object(
        offline_utils,
        "get_pyarrow_schema_from_batch_source",
        new=mock_get_pyarrow_schema_from_batch_source,
    ):
        RedshiftOfflineStore.offline_write_batch(
            repo_config, feature_view, pa_dataset, progress=None
        )

    # check that we have included the fully qualified table name
    mock_upload_arrow_table_to_redshift.assert_called_once()

    call = mock_upload_arrow_table_to_redshift.call_args_list[0]
    assert call.kwargs["table_name"] == "schema_name.table_name"

sdk/python/tests/unit/test_data_sources.py

+43
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,46 @@ def test_column_conflict():
190190
timestamp_field="event_timestamp",
191191
created_timestamp_column="event_timestamp",
192192
)
193+
194+
195+
@pytest.mark.parametrize(
    "source_kwargs,expected_name",
    [
        # database + schema + table all given explicitly
        (
            {
                "database": "test_database",
                "schema": "test_schema",
                "table": "test_table",
            },
            "test_database.test_schema.test_table",
        ),
        # missing schema falls back to Redshift's default "public" schema
        (
            {"database": "test_database", "table": "test_table"},
            "test_database.public.test_table",
        ),
        # no database: only schema.table can be produced
        ({"table": "test_table"}, "public.test_table"),
        # table already carries a schema; database comes from the kwarg
        ({"database": "test_database", "table": "b.c"}, "test_database.b.c"),
        # fully dotted table wins over the database kwarg
        ({"database": "test_database", "table": "a.b.c"}, "a.b.c"),
        # query-based source has no table, so the name is empty
        (
            {
                "database": "test_database",
                "schema": "test_schema",
                "query": "select * from abc",
            },
            "",
        ),
    ],
)
def test_redshift_fully_qualified_table_name(source_kwargs, expected_name):
    """fully_qualified_table_name combines database/schema/table as expected
    for every supported way of specifying a RedshiftSource."""
    redshift_source = RedshiftSource(
        name="test_source",
        timestamp_field="event_timestamp",
        created_timestamp_column="created_timestamp",
        field_mapping={"foo": "bar"},
        description="test description",
        tags={"test": "test"},
        owner="test@gmail.com",
        **source_kwargs,
    )

    assert redshift_source.redshift_options.fully_qualified_table_name == expected_name

0 commit comments

Comments
 (0)