Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Update redshift api #2479

Merged
merged 3 commits into from
Apr 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions protos/feast/core/DataSource.proto
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ message DataSource {

// Redshift schema name
string schema = 3;

// Redshift database name
string database = 4;
}

// Defines options for DataSource that sources features from a Snowflake Query
Expand Down
47 changes: 39 additions & 8 deletions sdk/python/feast/infra/offline_stores/redshift_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
description: Optional[str] = "",
tags: Optional[Dict[str, str]] = None,
owner: Optional[str] = "",
database: Optional[str] = "",
):
"""
Creates a RedshiftSource object.
Expand All @@ -47,11 +48,12 @@ def __init__(
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
owner (optional): The owner of the redshift source, typically the email of the primary
maintainer.
database (optional): The Redshift database name.
"""
# The default Redshift schema is named "public".
_schema = "public" if table and not schema else schema
self.redshift_options = RedshiftOptions(
table=table, schema=_schema, query=query
table=table, schema=_schema, query=query, database=database
)

if table is None and query is None:
Expand Down Expand Up @@ -102,6 +104,7 @@ def from_proto(data_source: DataSourceProto):
description=data_source.description,
tags=dict(data_source.tags),
owner=data_source.owner,
database=data_source.redshift_options.database,
)

# Note: Python requires redefining hash in child classes that override __eq__
Expand All @@ -119,6 +122,7 @@ def __eq__(self, other):
and self.redshift_options.table == other.redshift_options.table
and self.redshift_options.schema == other.redshift_options.schema
and self.redshift_options.query == other.redshift_options.query
and self.redshift_options.database == other.redshift_options.database
and self.event_timestamp_column == other.event_timestamp_column
and self.created_timestamp_column == other.created_timestamp_column
and self.field_mapping == other.field_mapping
Expand All @@ -139,9 +143,14 @@ def schema(self):

@property
def query(self):
"""Returns the Redshift options of this Redshift source."""
"""Returns the Redshift query of this Redshift source."""
return self.redshift_options.query

@property
def database(self):
"""Returns the Redshift database name of this Redshift source.

When this value is empty, get_table_column_names_and_types falls back to
the database configured on the offline store.
"""
return self.redshift_options.database

def to_proto(self) -> DataSourceProto:
"""
Converts a RedshiftSource object to its protobuf representation.
Expand Down Expand Up @@ -197,12 +206,15 @@ def get_table_column_names_and_types(
assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

client = aws_utils.get_redshift_data_client(config.offline_store.region)

if self.table is not None:
try:
table = client.describe_table(
ClusterIdentifier=config.offline_store.cluster_id,
Database=config.offline_store.database,
Database=(
self.database
if self.database
else config.offline_store.database
),
DbUser=config.offline_store.user,
Table=self.table,
Schema=self.schema,
Expand All @@ -221,7 +233,7 @@ def get_table_column_names_and_types(
statement_id = aws_utils.execute_redshift_statement(
client,
config.offline_store.cluster_id,
config.offline_store.database,
self.database if self.database else config.offline_store.database,
config.offline_store.user,
f"SELECT * FROM ({self.query}) LIMIT 1",
)
Expand All @@ -238,11 +250,16 @@ class RedshiftOptions:
"""

def __init__(
self, table: Optional[str], schema: Optional[str], query: Optional[str]
self,
table: Optional[str],
schema: Optional[str],
query: Optional[str],
database: Optional[str],
):
self._table = table
self._schema = schema
self._query = query
self._database = database

@property
def query(self):
Expand Down Expand Up @@ -274,6 +291,16 @@ def schema(self, schema):
"""Sets the schema of this Redshift table."""
self._schema = schema

@property
def database(self):
"""Returns the database name of this Redshift table."""
return self._database

@database.setter
def database(self, database):
"""Sets the database name of this Redshift table."""
self._database = database

@classmethod
def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
"""
Expand All @@ -289,6 +316,7 @@ def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
table=redshift_options_proto.table,
schema=redshift_options_proto.schema,
query=redshift_options_proto.query,
database=redshift_options_proto.database,
)

return redshift_options
Expand All @@ -301,7 +329,10 @@ def to_proto(self) -> DataSourceProto.RedshiftOptions:
A RedshiftOptionsProto protobuf.
"""
redshift_options_proto = DataSourceProto.RedshiftOptions(
table=self.table, schema=self.schema, query=self.query,
table=self.table,
schema=self.schema,
query=self.query,
database=self.database,
)

return redshift_options_proto
Expand All @@ -314,7 +345,7 @@ class SavedDatasetRedshiftStorage(SavedDatasetStorage):

def __init__(self, table_ref: str):
self.redshift_options = RedshiftOptions(
table=table_ref, schema=None, query=None
table=table_ref, schema=None, query=None, database=None
)

@staticmethod
Expand Down
5 changes: 4 additions & 1 deletion sdk/python/feast/templates/aws/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ def bootstrap():

repo_path = pathlib.Path(__file__).parent.absolute()
config_file = repo_path / "feature_store.yaml"
driver_file = repo_path / "driver_repo.py"

replace_str_in_file(config_file, "%AWS_REGION%", aws_region)
replace_str_in_file(config_file, "%REDSHIFT_CLUSTER_ID%", cluster_id)
replace_str_in_file(config_file, "%REDSHIFT_DATABASE%", database)
replace_str_in_file(driver_file, "%REDSHIFT_DATABASE%", database)
replace_str_in_file(config_file, "%REDSHIFT_USER%", user)
replace_str_in_file(
config_file, "%REDSHIFT_S3_STAGING_LOCATION%", s3_staging_location
driver_file, config_file, "%REDSHIFT_S3_STAGING_LOCATION%", s3_staging_location
)
replace_str_in_file(config_file,)
replace_str_in_file(config_file, "%REDSHIFT_IAM_ROLE%", iam_role)


Expand Down
2 changes: 2 additions & 0 deletions sdk/python/feast/templates/aws/driver_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
# The (optional) created timestamp is used to ensure there are no duplicate
# feature rows in the offline store or when building training datasets
created_timestamp_column="created",
# Name of the Redshift database that this source is read from.
database="%REDSHIFT_DATABASE%",
)

# Feature views are a grouping based on how features are stored in either the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def create_data_source(
created_timestamp_column=created_timestamp_column,
date_partition_column="",
field_mapping=field_mapping or {"ts_1": "ts"},
database=self.offline_store_config.database,
)

def create_saved_dataset_destination(self) -> SavedDatasetRedshiftStorage:
Expand Down