Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix snow git execute for Python files #1666

Merged
merged 5 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

## Fixes and improvements
* Fixed a bug that would cause the `deploy_root`, `bundle_root`, and `generated_root` directories to be created in the current working directory instead of the project root when invoking commands with the `--project` flag from a different directory.
* Align variables for `snow stage|git execute`. For Python files variables are stripped of leading and trailing quotes.

# v3.0.2

Expand All @@ -35,7 +36,7 @@
## Fixes and improvements

* Fixed the handling empty default values for strings by `snow snowpark deploy`.
* Added log error details if the `pip` command fails.
* Added log error details if the `pip` command fails.* Fix `snow git execute` support for Python files.

# v3.0.1

Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/cli/_plugins/git/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,10 @@ def execute(
extension will be executed.
"""
results = GitManager().execute(
stage_path=repository_path, on_error=on_error, variables=variables
stage_path=repository_path,
on_error=on_error,
variables=variables,
requires_temporary_stage=True,
)
return CollectionResult(results)

Expand Down
102 changes: 90 additions & 12 deletions src/snowflake/cli/_plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import re
import sys
import time
from contextlib import nullcontext
from dataclasses import dataclass
from os import path
Expand Down Expand Up @@ -90,12 +91,12 @@ def get_directory_from_file_path(self, file_path: str) -> List[str]:
raise NotImplementedError

def get_full_stage_path(self, path: str):
if prefix := FQN.from_stage(self.stage).prefix:
if prefix := FQN.from_stage_path(self.stage).prefix:
return prefix + "." + path
return path

def get_standard_stage_path(self) -> str:
path = self.path
path = self.get_full_stage_path(self.path)
return f"@{path}{'/'if self.is_directory and not path.endswith('/') else ''}"

def get_standard_stage_directory_path(self) -> str:
Expand All @@ -104,6 +105,9 @@ def get_standard_stage_directory_path(self) -> str:
return path + "/"
return path

def strip_stage_prefix(self, path: str):
raise NotImplementedError


@dataclass
class DefaultStagePathParts(StagePathParts):
Expand Down Expand Up @@ -141,6 +145,13 @@ def replace_stage_prefix(self, file_path: str) -> str:
file_path_without_prefix = Path(file_path).parts[OMIT_FIRST]
return f"{stage}/{'/'.join(file_path_without_prefix)}"

def strip_stage_prefix(self, file_path: str) -> str:
if file_path.startswith("@"):
file_path = file_path[OMIT_FIRST]
if file_path.startswith(self.stage_name):
return file_path[len(self.stage_name) :]
return file_path

def add_stage_prefix(self, file_path: str) -> str:
stage = self.stage.rstrip("/")
return f"{stage}/{file_path.lstrip('/')}"
Expand Down Expand Up @@ -312,7 +323,7 @@ def copy_files(self, source_path: str, destination_path: str) -> SnowflakeCursor
)

source = source_path_parts.get_standard_stage_path()
destination = destination_path_parts.get_standard_stage_directory_path()
destination = destination_path_parts.get_standard_stage_path()
log.info("Copying files from %s to %s", source, destination)
query = f"copy files into {destination} from {source}"
return self._execute_query(query)
Expand All @@ -332,8 +343,11 @@ def remove(
quoted_stage_name = self.quote_stage_name(f"{stage_name}{path}")
return self._execute_query(f"remove {quoted_stage_name}")

def create(self, fqn: FQN, comment: Optional[str] = None) -> SnowflakeCursor:
query = f"create stage if not exists {fqn.sql_identifier}"
def create(
self, fqn: FQN, comment: Optional[str] = None, temporary: bool = False
) -> SnowflakeCursor:
temporary_str = "temporary " if temporary else ""
query = f"create {temporary_str}stage if not exists {fqn.sql_identifier}"
if comment:
query += f" comment='{comment}'"
return self._execute_query(query)
Expand All @@ -347,8 +361,16 @@ def execute(
stage_path: str,
on_error: OnErrorType,
variables: Optional[List[str]] = None,
requires_temporary_stage: bool = False,
):
stage_path_parts = self._stage_path_part_factory(stage_path)
if requires_temporary_stage:
(
stage_path_parts,
original_path_parts,
) = self._create_temporary_copy_of_stage(stage_path)
else:
stage_path_parts = self._stage_path_part_factory(stage_path)

all_files_list = self._get_files_list_from_stage(stage_path_parts)

all_files_with_stage_name_prefix = [
Expand All @@ -370,7 +392,7 @@ def execute(

parsed_variables = parse_key_value_variables(variables)
sql_variables = self._parse_execute_variables(parsed_variables)
python_variables = {str(v.key): v.value for v in parsed_variables}
python_variables = self._parse_python_variables(parsed_variables)
results = []

if any(file.endswith(".py") for file in sorted_file_path_list):
Expand All @@ -380,22 +402,62 @@ def execute(

for file_path in sorted_file_path_list:
file_stage_path = stage_path_parts.add_stage_prefix(file_path)

# For better reporting push down the information about original
# path if execution happens from temporary stage
if requires_temporary_stage:
original_path = original_path_parts.add_stage_prefix(file_path)
else:
original_path = file_stage_path

if file_path.endswith(".py"):
result = self._execute_python(
file_stage_path=file_stage_path,
on_error=on_error,
variables=python_variables,
original_file=original_path,
)
else:
result = self._call_execute_immediate(
file_stage_path=file_stage_path,
variables=sql_variables,
on_error=on_error,
original_file=original_path,
)
results.append(result)

return results

def _create_temporary_copy_of_stage(
self, stage_path: str
) -> tuple[StagePathParts, StagePathParts]:
sm = StageManager()

# Rewrite stage paths to temporary stage paths. Git paths become stage paths
original_path_parts = self._stage_path_part_factory(stage_path) # noqa: SLF001

tmp_stage_name = f"snowflake_cli_tmp_stage_{int(time.time())}"
tmp_stage = (
FQN.from_stage(tmp_stage_name).using_connection(conn=self._conn).identifier
)
stage_path_parts = sm._stage_path_part_factory( # noqa: SLF001
tmp_stage + "/" + original_path_parts.directory
)

# Create temporary stage, it will be dropped with end of session
sm.create(FQN.from_string(tmp_stage), temporary=True)

# Copy the content
self.copy_files(
source_path=original_path_parts.get_full_stage_path(
original_path_parts.stage_name
),
destination_path=stage_path_parts.get_full_stage_path(
stage_path_parts.stage_name
),
)
return stage_path_parts, original_path_parts

def _get_files_list_from_stage(
self, stage_path_parts: StagePathParts, pattern: str | None = None
) -> List[str]:
Expand Down Expand Up @@ -444,6 +506,17 @@ def _parse_execute_variables(variables: List[Variable]) -> Optional[str]:
query_parameters = [f"{v.key}=>{v.value}" for v in variables]
return f" using ({', '.join(query_parameters)})"

@staticmethod
def _parse_python_variables(variables: List[Variable]) -> Dict:
def _unwrap(s: str):
if s.startswith("'") and s.endswith("'"):
return s[1:-1]
if s.startswith('"') and s.endswith('"'):
return s[1:-1]
return s

return {str(v.key): _unwrap(v.value) for v in variables}

@staticmethod
def _success_result(file: str):
cli_console.warning(f"SUCCESS - {file}")
Expand All @@ -464,16 +537,17 @@ def _call_execute_immediate(
file_stage_path: str,
variables: Optional[str],
on_error: OnErrorType,
original_file: str,
) -> Dict:
try:
query = f"execute immediate from {self.quote_stage_name(file_stage_path)}"
if variables:
query += variables
self._execute_query(query)
return StageManager._success_result(file=file_stage_path)
return StageManager._success_result(file=original_file)
except ProgrammingError as e:
StageManager._handle_execution_exception(on_error=on_error, exception=e)
return StageManager._error_result(file=file_stage_path, msg=e.msg)
return StageManager._error_result(file=original_file, msg=e.msg)

@staticmethod
def _stage_path_part_factory(stage_path: str) -> StagePathParts:
Expand Down Expand Up @@ -566,7 +640,11 @@ def _python_execution_procedure(
return _python_execution_procedure

def _execute_python(
self, file_stage_path: str, on_error: OnErrorType, variables: Dict
self,
file_stage_path: str,
on_error: OnErrorType,
variables: Dict,
original_file: str,
):
"""
Executes Python file from stage using a Snowpark temporary procedure.
Expand All @@ -576,7 +654,7 @@ def _execute_python(

try:
self._python_exe_procedure(self.get_standard_stage_prefix(file_stage_path), variables) # type: ignore
return StageManager._success_result(file=file_stage_path)
return StageManager._success_result(file=original_file)
except SnowparkSQLException as e:
StageManager._handle_execution_exception(on_error=on_error, exception=e)
return StageManager._error_result(file=file_stage_path, msg=e.message)
return StageManager._error_result(file=original_file, msg=e.message)
6 changes: 6 additions & 0 deletions src/snowflake/cli/api/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import re
from pathlib import Path

from click import ClickException
from snowflake.cli.api.exceptions import FQNInconsistencyError, FQNNameError
Expand Down Expand Up @@ -123,6 +124,11 @@ def from_stage(cls, stage: str) -> "FQN":
name = stage[1:]
return cls.from_string(name)

@classmethod
def from_stage_path(cls, stage_path: str) -> "FQN":
stage = Path(stage_path).parts[0]
return cls.from_stage(stage)

@classmethod
def from_identifier_model_v1(cls, model: ObjectIdentifierBaseModel) -> "FQN":
"""Create an instance from object model."""
Expand Down
5 changes: 5 additions & 0 deletions tests/api/test_fqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,8 @@ def test_using_context(mock_ctx):
mock_ctx().connection = MagicMock(database="database_test", schema="test_schema")
fqn = FQN.from_string("name").using_context()
assert fqn.identifier == "database_test.test_schema.name"


def test_git_fqn():
fqn = FQN.from_stage_path("@git_repo/branches/main/devops/")
assert fqn.name == "git_repo"
Loading
Loading