|
| 1 | +import re |
| 2 | +from unittest.mock import ANY, MagicMock, patch |
| 3 | + |
| 4 | +import pytest |
| 5 | + |
| 6 | +from feast.infra.offline_stores.snowflake import ( |
| 7 | + SnowflakeOfflineStoreConfig, |
| 8 | + SnowflakeRetrievalJob, |
| 9 | +) |
| 10 | +from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig |
| 11 | +from feast.repo_config import RepoConfig |
| 12 | + |
| 13 | + |
| 14 | +@pytest.fixture(params=["s3", "s3gov"]) |
| 15 | +def retrieval_job(request): |
| 16 | + offline_store_config = SnowflakeOfflineStoreConfig( |
| 17 | + type="snowflake.offline", |
| 18 | + account="snow", |
| 19 | + user="snow", |
| 20 | + password="snow", |
| 21 | + role="snow", |
| 22 | + warehouse="snow", |
| 23 | + database="FEAST", |
| 24 | + schema="OFFLINE", |
| 25 | + storage_integration_name="FEAST_S3", |
| 26 | + blob_export_location=f"{request.param}://feast-snowflake-offload/export", |
| 27 | + ) |
| 28 | + retrieval_job = SnowflakeRetrievalJob( |
| 29 | + query="SELECT * FROM snowflake", |
| 30 | + snowflake_conn=MagicMock(), |
| 31 | + config=RepoConfig( |
| 32 | + registry="s3://ml-test/repo/registry.db", |
| 33 | + project="test", |
| 34 | + provider="snowflake.offline", |
| 35 | + online_store=SqliteOnlineStoreConfig(type="sqlite"), |
| 36 | + offline_store=offline_store_config, |
| 37 | + ), |
| 38 | + full_feature_names=True, |
| 39 | + on_demand_feature_views=[], |
| 40 | + ) |
| 41 | + return retrieval_job |
| 42 | + |
| 43 | + |
| 44 | +def test_to_remote_storage(retrieval_job): |
| 45 | + stored_files = ["just a path", "maybe another"] |
| 46 | + with patch.object( |
| 47 | + retrieval_job, "to_snowflake", return_value=None |
| 48 | + ) as mock_to_snowflake, patch.object( |
| 49 | + retrieval_job, "_get_file_names_from_copy_into", return_value=stored_files |
| 50 | + ) as mock_get_file_names_from_copy: |
| 51 | + assert ( |
| 52 | + retrieval_job.to_remote_storage() == stored_files |
| 53 | + ), "should return the list of files" |
| 54 | + mock_to_snowflake.assert_called_once() |
| 55 | + mock_get_file_names_from_copy.assert_called_once_with(ANY, ANY) |
| 56 | + native_path = mock_get_file_names_from_copy.call_args[0][1] |
| 57 | + assert re.match("^s3://.*", native_path), "path should be s3://*" |
0 commit comments