Skip to content

Commit

Permalink
Fix snow git execute for Python files
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek committed Oct 3, 2024
1 parent 917e5cc commit 8895ff0
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 6 deletions.
3 changes: 3 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@


## Fixes and improvements
* Fix `snow git execute` support for Python files.
* Align variables for `snow stage|git execute`. For Python files variables are stripped of leading and trailing quotes.


# v3.0.0

Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/cli/_plugins/git/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,10 @@ def execute(
extension will be executed.
"""
results = GitManager().execute(
stage_path=repository_path, on_error=on_error, variables=variables
stage_path=repository_path,
on_error=on_error,
variables=variables,
requires_temporary_stage=True,
)
return CollectionResult(results)

Expand Down
50 changes: 45 additions & 5 deletions src/snowflake/cli/_plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import re
import sys
import time
from contextlib import nullcontext
from dataclasses import dataclass
from os import path
Expand Down Expand Up @@ -90,7 +91,7 @@ def get_directory_from_file_path(self, file_path: str) -> List[str]:
raise NotImplementedError

def get_full_stage_path(self, path: str):
if prefix := FQN.from_stage(self.stage).prefix:
if prefix := FQN.from_stage_path(self.stage).prefix:
return prefix + "." + path
return path

Expand Down Expand Up @@ -332,8 +333,11 @@ def remove(
quoted_stage_name = self.quote_stage_name(f"{stage_name}{path}")
return self._execute_query(f"remove {quoted_stage_name}")

def create(self, fqn: FQN, comment: Optional[str] = None) -> SnowflakeCursor:
query = f"create stage if not exists {fqn.sql_identifier}"
def create(
self, fqn: FQN, comment: Optional[str] = None, temporary: bool = False
) -> SnowflakeCursor:
temporary_str = "temporary" if temporary else ""
query = f"create {temporary_str} stage if not exists {fqn.sql_identifier}"
if comment:
query += f" comment='{comment}'"
return self._execute_query(query)
Expand All @@ -347,8 +351,13 @@ def execute(
stage_path: str,
on_error: OnErrorType,
variables: Optional[List[str]] = None,
requires_temporary_stage: bool = False,
):
stage_path_parts = self._stage_path_part_factory(stage_path)
if requires_temporary_stage:
stage_path_parts = self._create_temporary_copy_of_stage(stage_path)
else:
stage_path_parts = self._stage_path_part_factory(stage_path)

all_files_list = self._get_files_list_from_stage(stage_path_parts)

all_files_with_stage_name_prefix = [
Expand All @@ -370,7 +379,7 @@ def execute(

parsed_variables = parse_key_value_variables(variables)
sql_variables = self._parse_execute_variables(parsed_variables)
python_variables = {str(v.key): v.value for v in parsed_variables}
python_variables = self._parse_python_variables(parsed_variables)
results = []

if any(file.endswith(".py") for file in sorted_file_path_list):
Expand All @@ -396,6 +405,26 @@ def execute(

return results

def _create_temporary_copy_of_stage(self, stage_path: str) -> StagePathParts:
tmp_stage = f"snowflake_cli_tmp_stage_{int(time.time())}"
sm = StageManager()

# Create temporary stage, it will be dropped with end of session
sm.create(FQN.from_string(tmp_stage), temporary=True)

# Rewrite stage paths to temporary stage paths. Git paths become stage paths
original_path_parts = self._stage_path_part_factory(stage_path) # noqa: SLF001
stage_path_parts = sm._stage_path_part_factory( # noqa: SLF001
tmp_stage + "/" + original_path_parts.directory
)

# Copy the content
self.copy_files(
source_path=original_path_parts.stage_name,
destination_path=stage_path_parts.stage_name,
)
return stage_path_parts

def _get_files_list_from_stage(
self, stage_path_parts: StagePathParts, pattern: str | None = None
) -> List[str]:
Expand Down Expand Up @@ -444,6 +473,17 @@ def _parse_execute_variables(variables: List[Variable]) -> Optional[str]:
query_parameters = [f"{v.key}=>{v.value}" for v in variables]
return f" using ({', '.join(query_parameters)})"

@staticmethod
def _parse_python_variables(variables: List[Variable]) -> Dict:
def _unwrap(s: str):
if s.startswith("'") and s.endswith("'"):
return s[1:-1]
if s.startswith('"') and s.endswith('"'):
return s[1:-1]
return s

return {str(v.key): _unwrap(v.value) for v in variables}

@staticmethod
def _success_result(file: str):
cli_console.warning(f"SUCCESS - {file}")
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/cli/api/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import re
from pathlib import Path

from click import ClickException
from snowflake.cli.api.exceptions import FQNInconsistencyError, FQNNameError
Expand Down Expand Up @@ -123,6 +124,11 @@ def from_stage(cls, stage: str) -> "FQN":
name = stage[1:]
return cls.from_string(name)

@classmethod
def from_stage_path(cls, stage_path: str) -> "FQN":
stage = Path(stage_path).parts[0]
return cls.from_stage(stage)

@classmethod
def from_identifier_model_v1(cls, model: ObjectIdentifierBaseModel) -> "FQN":
"""Create an instance from object model."""
Expand Down
5 changes: 5 additions & 0 deletions tests/api/test_fqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,8 @@ def test_using_context(mock_ctx):
mock_ctx().connection = MagicMock(database="database_test", schema="test_schema")
fqn = FQN.from_string("name").using_context()
assert fqn.identifier == "database_test.test_schema.name"


def test_git_fqn():
fqn = FQN.from_stage_path("@git_repo/branches/main/devops/")
assert fqn.name == "git_repo"
9 changes: 9 additions & 0 deletions tests_integration/__snapshots__/test_git.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
}),
])
# ---
# name: test_execute_python
list([
dict({
'Error': None,
'File': '@snowflake_cli_tmp_stage_1727969525/tests_integration/test_data/projects/stage_execute/script1.py',
'Status': 'SUCCESS',
}),
])
# ---
# name: test_execute_with_name_in_pascal_case
list([
dict({
Expand Down
14 changes: 14 additions & 0 deletions tests_integration/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,20 @@ def test_execute(runner, test_database, sf_git_repository, snapshot):
assert result.json == snapshot


@pytest.mark.integration
def test_execute_python(runner, test_database, sf_git_repository, snapshot):
result = runner.invoke_with_connection_json(
[
"git",
"execute",
f"@{sf_git_repository.lower()}/branches/main/tests_integration/test_data/projects/stage_execute/script1.py",
]
)

assert result.exit_code == 0
assert result.json == snapshot


@pytest.mark.integration
def test_execute_fqn_repo(runner, test_database, sf_git_repository):
result_fqn = runner.invoke_with_connection_json(
Expand Down

0 comments on commit 8895ff0

Please sign in to comment.