Skip to content

Commit 5abed74

Browse files
feat: Added SnowflakeConnection caching
Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
1 parent 9a3fd98 commit 5abed74

File tree

7 files changed

+122
-102
lines changed

7 files changed

+122
-102
lines changed

sdk/python/feast/infra/materialization/snowflake_engine.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
from feast.infra.online_stores.online_store import OnlineStore
2626
from feast.infra.registry.base_registry import BaseRegistry
2727
from feast.infra.utils.snowflake.snowflake_utils import (
28+
GetSnowflakeConnection,
2829
_run_snowflake_field_mapping,
2930
assert_snowflake_feature_names,
3031
execute_snowflake_statement,
31-
get_snowflake_conn,
3232
get_snowflake_online_store_path,
3333
package_snowpark_zip,
3434
)
@@ -121,7 +121,7 @@ def update(
121121
):
122122
stage_context = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"'
123123
stage_path = f'{stage_context}."feast_{project}"'
124-
with get_snowflake_conn(self.repo_config.batch_engine) as conn:
124+
with GetSnowflakeConnection(self.repo_config.batch_engine) as conn:
125125
query = f"SHOW STAGES IN {stage_context}"
126126
cursor = execute_snowflake_statement(conn, query)
127127
stage_list = pd.DataFrame(
@@ -173,7 +173,7 @@ def teardown_infra(
173173
):
174174

175175
stage_path = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"."feast_{project}"'
176-
with get_snowflake_conn(self.repo_config.batch_engine) as conn:
176+
with GetSnowflakeConnection(self.repo_config.batch_engine) as conn:
177177
query = f"DROP STAGE IF EXISTS {stage_path}"
178178
execute_snowflake_statement(conn, query)
179179

@@ -263,10 +263,11 @@ def _materialize_one(
263263

264264
# Lets check and see if we can skip this query, because the table hasnt changed
265265
# since before the start date of this query
266-
with get_snowflake_conn(self.repo_config.offline_store) as conn:
266+
with GetSnowflakeConnection(self.repo_config.offline_store) as conn:
267267
query = f"""SELECT SYSTEM$LAST_CHANGE_COMMIT_TIME('{feature_view.batch_source.get_table_query_string()}') AS last_commit_change_time"""
268268
last_commit_change_time = (
269-
conn.cursor().execute(query).fetchall()[0][0] / 1_000_000_000
269+
execute_snowflake_statement(conn, query).fetchall()[0][0]
270+
/ 1_000_000_000
270271
)
271272
if last_commit_change_time < start_date.astimezone(tz=utc).timestamp():
272273
return SnowflakeMaterializationJob(
@@ -432,7 +433,7 @@ def materialize_to_snowflake_online_store(
432433
)
433434
"""
434435

435-
with get_snowflake_conn(repo_config.batch_engine) as conn:
436+
with GetSnowflakeConnection(repo_config.batch_engine) as conn:
436437
query_id = execute_snowflake_statement(conn, query).sfqid
437438

438439
click.echo(
@@ -450,7 +451,7 @@ def materialize_to_external_online_store(
450451

451452
feature_names = [feature.name for feature in feature_view.features]
452453

453-
with get_snowflake_conn(repo_config.batch_engine) as conn:
454+
with GetSnowflakeConnection(repo_config.batch_engine) as conn:
454455
query = materialization_sql
455456
cursor = execute_snowflake_statement(conn, query)
456457
for i, df in enumerate(cursor.fetch_pandas_batches()):

sdk/python/feast/infra/offline_stores/snowflake.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@
4848
)
4949
from feast.infra.registry.base_registry import BaseRegistry
5050
from feast.infra.utils.snowflake.snowflake_utils import (
51+
GetSnowflakeConnection,
5152
execute_snowflake_statement,
52-
get_snowflake_conn,
5353
write_pandas,
5454
write_parquet,
5555
)
@@ -74,13 +74,13 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel):
7474
"""Offline store config for Snowflake"""
7575

7676
type: Literal["snowflake.offline"] = "snowflake.offline"
77-
""" Offline store type selector"""
77+
""" Offline store type selector """
7878

7979
config_path: Optional[str] = os.path.expanduser("~/.snowsql/config")
8080
""" Snowflake config path -- absolute path required (Cant use ~)"""
8181

8282
account: Optional[str] = None
83-
""" Snowflake deployment identifier -- drop .snowflakecomputing.com"""
83+
""" Snowflake deployment identifier -- drop .snowflakecomputing.com """
8484

8585
user: Optional[str] = None
8686
""" Snowflake user name """
@@ -89,7 +89,7 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel):
8989
""" Snowflake password """
9090

9191
role: Optional[str] = None
92-
""" Snowflake role name"""
92+
""" Snowflake role name """
9393

9494
warehouse: Optional[str] = None
9595
""" Snowflake warehouse name """
@@ -155,7 +155,8 @@ def pull_latest_from_table_or_query(
155155
if data_source.snowflake_options.warehouse:
156156
config.offline_store.warehouse = data_source.snowflake_options.warehouse
157157

158-
snowflake_conn = get_snowflake_conn(config.offline_store)
158+
with GetSnowflakeConnection(config.offline_store) as conn:
159+
snowflake_conn = conn
159160

160161
start_date = start_date.astimezone(tz=utc)
161162
end_date = end_date.astimezone(tz=utc)
@@ -208,7 +209,8 @@ def pull_all_from_table_or_query(
208209
if data_source.snowflake_options.warehouse:
209210
config.offline_store.warehouse = data_source.snowflake_options.warehouse
210211

211-
snowflake_conn = get_snowflake_conn(config.offline_store)
212+
with GetSnowflakeConnection(config.offline_store) as conn:
213+
snowflake_conn = conn
212214

213215
start_date = start_date.astimezone(tz=utc)
214216
end_date = end_date.astimezone(tz=utc)
@@ -241,7 +243,8 @@ def get_historical_features(
241243
for fv in feature_views:
242244
assert isinstance(fv.batch_source, SnowflakeSource)
243245

244-
snowflake_conn = get_snowflake_conn(config.offline_store)
246+
with GetSnowflakeConnection(config.offline_store) as conn:
247+
snowflake_conn = conn
245248

246249
entity_schema = _get_entity_schema(entity_df, snowflake_conn, config)
247250

@@ -319,7 +322,8 @@ def write_logged_features(
319322
):
320323
assert isinstance(logging_config.destination, SnowflakeLoggingDestination)
321324

322-
snowflake_conn = get_snowflake_conn(config.offline_store)
325+
with GetSnowflakeConnection(config.offline_store) as conn:
326+
snowflake_conn = conn
323327

324328
if isinstance(data, Path):
325329
write_parquet(
@@ -359,7 +363,8 @@ def offline_write_batch(
359363
if table.schema != pa_schema:
360364
table = table.cast(pa_schema)
361365

362-
snowflake_conn = get_snowflake_conn(config.offline_store)
366+
with GetSnowflakeConnection(config.offline_store) as conn:
367+
snowflake_conn = conn
363368

364369
write_pandas(
365370
snowflake_conn,
@@ -427,7 +432,6 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
427432
).fetch_arrow_all()
428433

429434
if pa_table:
430-
431435
return pa_table
432436
else:
433437
empty_result = execute_snowflake_statement(self.snowflake_conn, query)

sdk/python/feast/infra/offline_stores/snowflake_source.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,13 @@ def get_table_column_names_and_types(
213213
"""
214214
from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig
215215
from feast.infra.utils.snowflake.snowflake_utils import (
216+
GetSnowflakeConnection,
216217
execute_snowflake_statement,
217-
get_snowflake_conn,
218218
)
219219

220220
assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)
221221

222-
with get_snowflake_conn(config.offline_store) as conn:
222+
with GetSnowflakeConnection(config.offline_store) as conn:
223223
query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 5"
224224
cursor = execute_snowflake_statement(conn, query)
225225

@@ -250,7 +250,7 @@ def get_table_column_names_and_types(
250250
else:
251251
column = row["column_name"]
252252

253-
with get_snowflake_conn(config.offline_store) as conn:
253+
with GetSnowflakeConnection(config.offline_store) as conn:
254254
query = f'SELECT MAX("{column}") AS "{column}" FROM {self.get_table_query_string()}'
255255
result = execute_snowflake_statement(
256256
conn, query

sdk/python/feast/infra/online_stores/snowflake.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from feast.infra.key_encoding_utils import serialize_entity_key
1414
from feast.infra.online_stores.online_store import OnlineStore
1515
from feast.infra.utils.snowflake.snowflake_utils import (
16+
GetSnowflakeConnection,
1617
execute_snowflake_statement,
17-
get_snowflake_conn,
1818
get_snowflake_online_store_path,
1919
write_pandas_binary,
2020
)
@@ -29,13 +29,13 @@ class SnowflakeOnlineStoreConfig(FeastConfigBaseModel):
2929
"""Online store config for Snowflake"""
3030

3131
type: Literal["snowflake.online"] = "snowflake.online"
32-
""" Online store type selector"""
32+
""" Online store type selector """
3333

3434
config_path: Optional[str] = os.path.expanduser("~/.snowsql/config")
3535
""" Snowflake config path -- absolute path required (Can't use ~)"""
3636

3737
account: Optional[str] = None
38-
""" Snowflake deployment identifier -- drop .snowflakecomputing.com"""
38+
""" Snowflake deployment identifier -- drop .snowflakecomputing.com """
3939

4040
user: Optional[str] = None
4141
""" Snowflake user name """
@@ -44,7 +44,7 @@ class SnowflakeOnlineStoreConfig(FeastConfigBaseModel):
4444
""" Snowflake password """
4545

4646
role: Optional[str] = None
47-
""" Snowflake role name"""
47+
""" Snowflake role name """
4848

4949
warehouse: Optional[str] = None
5050
""" Snowflake warehouse name """
@@ -114,7 +114,7 @@ def online_write_batch(
114114

115115
# This combines both the data upload plus the overwrite in the same transaction
116116
online_path = get_snowflake_online_store_path(config, table)
117-
with get_snowflake_conn(config.online_store, autocommit=False) as conn:
117+
with GetSnowflakeConnection(config.online_store, autocommit=False) as conn:
118118
write_pandas_binary(
119119
conn,
120120
agg_df,
@@ -178,7 +178,7 @@ def online_read(
178178
)
179179

180180
online_path = get_snowflake_online_store_path(config, table)
181-
with get_snowflake_conn(config.online_store) as conn:
181+
with GetSnowflakeConnection(config.online_store) as conn:
182182
query = f"""
183183
SELECT
184184
"entity_key", "feature_name", "value", "event_ts"
@@ -220,7 +220,7 @@ def update(
220220
):
221221
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)
222222

223-
with get_snowflake_conn(config.online_store) as conn:
223+
with GetSnowflakeConnection(config.online_store) as conn:
224224
for table in tables_to_keep:
225225
online_path = get_snowflake_online_store_path(config, table)
226226
query = f"""
@@ -248,7 +248,7 @@ def teardown(
248248
):
249249
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)
250250

251-
with get_snowflake_conn(config.online_store) as conn:
251+
with GetSnowflakeConnection(config.online_store) as conn:
252252
for table in tables:
253253
online_path = get_snowflake_online_store_path(config, table)
254254
query = f'DROP TABLE IF EXISTS {online_path}."[online-transient] {config.project}_{table.name}"'

sdk/python/feast/infra/registry/snowflake.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from feast.infra.registry import proto_registry_utils
2929
from feast.infra.registry.base_registry import BaseRegistry
3030
from feast.infra.utils.snowflake.snowflake_utils import (
31+
GetSnowflakeConnection,
3132
execute_snowflake_statement,
32-
get_snowflake_conn,
3333
)
3434
from feast.on_demand_feature_view import OnDemandFeatureView
3535
from feast.project_metadata import ProjectMetadata
@@ -121,7 +121,7 @@ def __init__(
121121
f'"{self.registry_config.database}"."{self.registry_config.schema_}"'
122122
)
123123

124-
with get_snowflake_conn(self.registry_config) as conn:
124+
with GetSnowflakeConnection(self.registry_config) as conn:
125125
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql"
126126
with open(sql_function_file, "r") as file:
127127
sqlFile = file.read()
@@ -177,7 +177,7 @@ def _refresh_cached_registry_if_necessary(self):
177177
self.refresh()
178178

179179
def teardown(self):
180-
with get_snowflake_conn(self.registry_config) as conn:
180+
with GetSnowflakeConnection(self.registry_config) as conn:
181181
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql"
182182
with open(sql_function_file, "r") as file:
183183
sqlFile = file.read()
@@ -284,7 +284,7 @@ def _apply_object(
284284
if hasattr(obj, "last_updated_timestamp"):
285285
obj.last_updated_timestamp = update_datetime
286286

287-
with get_snowflake_conn(self.registry_config) as conn:
287+
with GetSnowflakeConnection(self.registry_config) as conn:
288288
query = f"""
289289
SELECT
290290
project_id
@@ -405,7 +405,7 @@ def _delete_object(
405405
id_field_name: str,
406406
not_found_exception: Optional[Callable],
407407
):
408-
with get_snowflake_conn(self.registry_config) as conn:
408+
with GetSnowflakeConnection(self.registry_config) as conn:
409409
query = f"""
410410
DELETE FROM {self.registry_path}."{table}"
411411
WHERE
@@ -616,7 +616,7 @@ def _get_object(
616616
not_found_exception: Optional[Callable],
617617
):
618618
self._maybe_init_project_metadata(project)
619-
with get_snowflake_conn(self.registry_config) as conn:
619+
with GetSnowflakeConnection(self.registry_config) as conn:
620620
query = f"""
621621
SELECT
622622
{proto_field_name}
@@ -776,7 +776,7 @@ def _list_objects(
776776
proto_field_name: str,
777777
):
778778
self._maybe_init_project_metadata(project)
779-
with get_snowflake_conn(self.registry_config) as conn:
779+
with GetSnowflakeConnection(self.registry_config) as conn:
780780
query = f"""
781781
SELECT
782782
{proto_field_name}
@@ -839,7 +839,7 @@ def list_project_metadata(
839839
return proto_registry_utils.list_project_metadata(
840840
self.cached_registry_proto, project
841841
)
842-
with get_snowflake_conn(self.registry_config) as conn:
842+
with GetSnowflakeConnection(self.registry_config) as conn:
843843
query = f"""
844844
SELECT
845845
metadata_key,
@@ -869,7 +869,7 @@ def apply_user_metadata(
869869
):
870870
fv_table_str = self._infer_fv_table(feature_view)
871871
fv_column_name = fv_table_str[:-1].lower()
872-
with get_snowflake_conn(self.registry_config) as conn:
872+
with GetSnowflakeConnection(self.registry_config) as conn:
873873
query = f"""
874874
SELECT
875875
project_id
@@ -905,7 +905,7 @@ def get_user_metadata(
905905
) -> Optional[bytes]:
906906
fv_table_str = self._infer_fv_table(feature_view)
907907
fv_column_name = fv_table_str[:-1].lower()
908-
with get_snowflake_conn(self.registry_config) as conn:
908+
with GetSnowflakeConnection(self.registry_config) as conn:
909909
query = f"""
910910
SELECT
911911
user_metadata
@@ -971,7 +971,7 @@ def _get_all_projects(self) -> Set[str]:
971971
"STREAM_FEATURE_VIEWS",
972972
]
973973

974-
with get_snowflake_conn(self.registry_config) as conn:
974+
with GetSnowflakeConnection(self.registry_config) as conn:
975975
for table in base_tables:
976976
query = (
977977
f'SELECT DISTINCT project_id FROM {self.registry_path}."{table}"'
@@ -984,7 +984,7 @@ def _get_all_projects(self) -> Set[str]:
984984
return projects
985985

986986
def _get_last_updated_metadata(self, project: str):
987-
with get_snowflake_conn(self.registry_config) as conn:
987+
with GetSnowflakeConnection(self.registry_config) as conn:
988988
query = f"""
989989
SELECT
990990
metadata_value
@@ -1029,7 +1029,7 @@ def _infer_fv_table(self, feature_view) -> str:
10291029
return table
10301030

10311031
def _maybe_init_project_metadata(self, project):
1032-
with get_snowflake_conn(self.registry_config) as conn:
1032+
with GetSnowflakeConnection(self.registry_config) as conn:
10331033
query = f"""
10341034
SELECT
10351035
metadata_value
@@ -1056,7 +1056,7 @@ def _maybe_init_project_metadata(self, project):
10561056
usage.set_current_project_uuid(new_project_uuid)
10571057

10581058
def _set_last_updated_metadata(self, last_updated: datetime, project: str):
1059-
with get_snowflake_conn(self.registry_config) as conn:
1059+
with GetSnowflakeConnection(self.registry_config) as conn:
10601060
query = f"""
10611061
SELECT
10621062
project_id

0 commit comments

Comments
 (0)